diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 5dfdcd0132..a9e794b3f9 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -26,7 +26,11 @@ jobs: windows-2019-cl, ] - build_type: [Debug, Release] + build_type: [ + Debug, + #TODO(Varun) The release build takes over 2.5 hours, need to figure out why. + # Release + ] build_unstable: [ON] include: #TODO This build fails, need to understand why. @@ -90,13 +94,18 @@ jobs: - name: Checkout uses: actions/checkout@v2 - - name: Build + - name: Configuration run: | cmake -E remove_directory build cmake -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib" - cmake --build build --config ${{ matrix.build_type }} --target gtsam - cmake --build build --config ${{ matrix.build_type }} --target gtsam_unstable - cmake --build build --config ${{ matrix.build_type }} --target wrap - cmake --build build --config ${{ matrix.build_type }} --target check.base - cmake --build build --config ${{ matrix.build_type }} --target check.base_unstable - cmake --build build --config ${{ matrix.build_type }} --target check.linear + + - name: Build + run: | + # Since Visual Studio is a multi-generator, we need to use --config + # https://stackoverflow.com/a/24470998/1236990 + cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam + cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable + cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d871424c4..74019da446 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,12 +9,18 @@ endif() # Set the version number for the library set (GTSAM_VERSION_MAJOR 4) -set (GTSAM_VERSION_MINOR 1) -set (GTSAM_VERSION_PATCH 1) +set (GTSAM_VERSION_MINOR 2) +set (GTSAM_VERSION_PATCH 0) +set (GTSAM_PRERELEASE_VERSION "a0") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") -set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") -set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING}) +if (${GTSAM_VERSION_PATCH} EQUAL 0) + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}") +else() + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}") +endif() +message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}") + set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR}) set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH}) @@ -87,6 +93,13 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX) CACHE STRING "The Python version to use for wrapping") # Set the include directory for matlab.h set(GTWRAP_INCLUDE_NAME "wrap") + + # Copy matlab.h to the correct folder. + configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h + ${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY) + # Add the include directories so that matlab.h can be found + include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}") + add_subdirectory(wrap) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake") endif() diff --git a/README.md b/README.md index 0461323016..ee5746e1cf 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ **Important Note** -As of August 1 2020, the `develop` branch is officially in "Pre 4.1" mode, and features deprecated in 4.0 have been removed. Please use the last [4.0.3 release](https://github.com/borglab/gtsam/releases/tag/4.0.3) if you need those features. +As of Dec 2021, the `develop` branch is officially in "Pre 4.2" mode. A great new feature we will be adding in 4.2 is *hybrid inference* a la DCSLAM (Kevin Doherty et al) and we envision several API-breaking changes will happen in the discrete folder. -However, most are easily converted and can be tracked down (in 4.0.3) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4`. +In addition, features deprecated in 4.1 will be removed. Please use the last [4.1.1 release](https://github.com/borglab/gtsam/releases/tag/4.1.1) if you need those features. However, most (not all, unfortunately) are easily converted and can be tracked down (in 4.1.1) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V41`. ## What is GTSAM? diff --git a/examples/DiscreteBayesNetExample.cpp b/examples/DiscreteBayesNetExample.cpp index 5dca116c35..febc1e1288 100644 --- a/examples/DiscreteBayesNetExample.cpp +++ b/examples/DiscreteBayesNetExample.cpp @@ -56,8 +56,8 @@ int main(int argc, char **argv) { DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); // solve - DiscreteFactor::sharedValues mpe = chordal->optimize(); - GTSAM_PRINT(*mpe); + auto mpe = chordal->optimize(); + GTSAM_PRINT(mpe); // We can also build a Bayes tree (directed junction tree). // The elimination order above will do fine: @@ -70,14 +70,14 @@ int main(int argc, char **argv) { // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - DiscreteFactor::sharedValues mpe2 = chordal2->optimize(); - GTSAM_PRINT(*mpe2); + auto mpe2 = chordal2->optimize(); + GTSAM_PRINT(mpe2); // We can also sample from it cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - DiscreteFactor::sharedValues sample = chordal2->sample(); - GTSAM_PRINT(*sample); + auto sample = chordal2->sample(); + GTSAM_PRINT(sample); } return 0; } diff --git a/examples/DiscreteBayesNet_FG.cpp b/examples/DiscreteBayesNet_FG.cpp index 121df4befe..69283a1be7 100644 --- a/examples/DiscreteBayesNet_FG.cpp +++ b/examples/DiscreteBayesNet_FG.cpp @@ -33,11 +33,11 @@ using namespace gtsam; int main(int argc, char **argv) { // Define keys and a print function Key C(1), S(2), R(3), W(4); - auto print = [=](DiscreteFactor::sharedValues values) { - cout << boolalpha << "Cloudy = " << static_cast((*values)[C]) - << " Sprinkler = " << static_cast((*values)[S]) - << " Rain = " << boolalpha << static_cast((*values)[R]) - << " WetGrass = " << static_cast((*values)[W]) << endl; + auto print = [=](const DiscreteFactor::Values& values) { + cout << boolalpha << "Cloudy = " << static_cast(values.at(C)) + << " Sprinkler = " << static_cast(values.at(S)) + << " Rain = " << boolalpha << static_cast(values.at(R)) + << " WetGrass = " << static_cast(values.at(W)) << endl; }; // We assume binary state variables @@ -85,7 +85,7 @@ int main(int argc, char **argv) { } // "Most Probable Explanation", i.e., configuration with largest value - DiscreteFactor::sharedValues mpe = graph.eliminateSequential()->optimize(); + auto mpe = graph.eliminateSequential()->optimize(); cout << "\nMost Probable Explanation (MPE):" << endl; print(mpe); @@ -97,7 +97,7 @@ int main(int argc, char **argv) { // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues mpe_with_evidence = chordal->optimize(); + auto mpe_with_evidence = chordal->optimize(); cout << "\nMPE given C=0:" << endl; print(mpe_with_evidence); @@ -113,7 +113,7 @@ int main(int argc, char **argv) { // We can also sample from it cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - DiscreteFactor::sharedValues sample = chordal->sample(); + auto sample = chordal->sample(); print(sample); } return 0; diff --git a/examples/FisheyeExample.cpp b/examples/FisheyeExample.cpp index 2231492993..fc0aed0d77 100644 --- a/examples/FisheyeExample.cpp +++ b/examples/FisheyeExample.cpp @@ -122,8 +122,7 @@ int main(int argc, char *argv[]) { std::cout << "initial error=" << graph.error(initialEstimate) << std::endl; std::cout << "final error=" << graph.error(result) << std::endl; - std::ofstream os("examples/vio_batch.dot"); - graph.saveGraph(os, result); + graph.saveGraph("examples/vio_batch.dot", result); return 0; } diff --git a/examples/HMMExample.cpp b/examples/HMMExample.cpp index ee861e3811..b46baf4e09 100644 --- a/examples/HMMExample.cpp +++ b/examples/HMMExample.cpp @@ -66,14 +66,14 @@ int main(int argc, char **argv) { chordal->print("Eliminated"); // solve - DiscreteFactor::sharedValues mpe = chordal->optimize(); - GTSAM_PRINT(*mpe); + auto mpe = chordal->optimize(); + GTSAM_PRINT(mpe); // We can also sample from it cout << "\n10 samples:" << endl; for (size_t k = 0; k < 10; k++) { - DiscreteFactor::sharedValues sample = chordal->sample(); - GTSAM_PRINT(*sample); + auto sample = chordal->sample(); + GTSAM_PRINT(sample); } // Or compute the marginals. This re-eliminates the FG into a Bayes tree diff --git a/examples/Pose2SLAMExample_graphviz.cpp b/examples/Pose2SLAMExample_graphviz.cpp index 27d5567252..a8768e2b8d 100644 --- a/examples/Pose2SLAMExample_graphviz.cpp +++ b/examples/Pose2SLAMExample_graphviz.cpp @@ -60,11 +60,10 @@ int main(int argc, char** argv) { // save factor graph as graphviz dot file // Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf" - ofstream os("Pose2SLAMExample.dot"); - graph.saveGraph(os, result); + graph.saveGraph("Pose2SLAMExample.dot", result); // Also print out to console - graph.saveGraph(cout, result); + graph.dot(cout, result); return 0; } diff --git a/examples/UGM_chain.cpp b/examples/UGM_chain.cpp index 3a885a844c..ababef0220 100644 --- a/examples/UGM_chain.cpp +++ b/examples/UGM_chain.cpp @@ -70,8 +70,8 @@ int main(int argc, char** argv) { // "Decoding", i.e., configuration with largest value // We use sequential variable elimination DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues optimalDecoding = chordal->optimize(); - optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n"); + auto optimalDecoding = chordal->optimize(); + optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); // "Inference" Computing marginals for each node // Here we'll make use of DiscreteMarginals class, which makes use of diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp index 27a6205a39..f4f3f1fd0b 100644 --- a/examples/UGM_small.cpp +++ b/examples/UGM_small.cpp @@ -63,8 +63,8 @@ int main(int argc, char** argv) { // "Decoding", i.e., configuration with largest value (MPE) // We use sequential variable elimination DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues optimalDecoding = chordal->optimize(); - optimalDecoding->print("\noptimalDecoding"); + auto optimalDecoding = chordal->optimize(); + GTSAM_PRINT(optimalDecoding); // "Inference" Computing marginals cout << "\nComputing Node Marginals .." << endl; diff --git a/gtsam/CMakeLists.txt b/gtsam/CMakeLists.txt index 535d60eb18..a293c6ec28 100644 --- a/gtsam/CMakeLists.txt +++ b/gtsam/CMakeLists.txt @@ -15,7 +15,7 @@ set (gtsam_subdirs sam sfm slam - navigation + navigation ) set(gtsam_srcs) diff --git a/gtsam/base/base.i b/gtsam/base/base.i index d9c51fbe83..9838f97d38 100644 --- a/gtsam/base/base.i +++ b/gtsam/base/base.i @@ -38,7 +38,7 @@ class DSFMap { DSFMap(); KEY find(const KEY& key) const; void merge(const KEY& x, const KEY& y); - std::map sets(); + std::map sets(); }; class IndexPairSet { diff --git a/gtsam/base/tests/testMatrix.cpp b/gtsam/base/tests/testMatrix.cpp index a7c2187059..7802f27e1d 100644 --- a/gtsam/base/tests/testMatrix.cpp +++ b/gtsam/base/tests/testMatrix.cpp @@ -173,7 +173,7 @@ TEST(Matrix, stack ) { Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished(); Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished(); - Matrix AB = stack(2, &A, &B); + Matrix AB = gtsam::stack(2, &A, &B); Matrix C(5, 2); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) @@ -187,7 +187,7 @@ TEST(Matrix, stack ) std::vector matrices; matrices.push_back(A); matrices.push_back(B); - Matrix AB2 = stack(matrices); + Matrix AB2 = gtsam::stack(matrices); EQUALITY(C,AB2); } diff --git a/gtsam/basis/basis.i b/gtsam/basis/basis.i index 8f06fd2e13..c9c0274388 100644 --- a/gtsam/basis/basis.i +++ b/gtsam/basis/basis.i @@ -140,7 +140,7 @@ class FitBasis { static gtsam::GaussianFactorGraph::shared_ptr LinearGraph( const std::map& sequence, const gtsam::noiseModel::Base* model, size_t N); - Parameters parameters() const; + This::Parameters parameters() const; }; } // namespace gtsam diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 439889ebfc..f6a64f11fc 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -248,8 +248,9 @@ namespace gtsam { void dot(std::ostream& os, bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; - for (size_t i = 0; i < branches_.size(); i++) { - NodePtr branch = branches_[i]; + size_t B = branches_.size(); + for (size_t i = 0; i < B; i++) { + const NodePtr& branch = branches_[i]; // Check if zero if (!showZero) { @@ -258,8 +259,10 @@ namespace gtsam { } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; - if (i == 0) os << " [style=dashed]"; - if (i > 1) os << " [style=bold]"; + if (B == 2) { + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + } os << std::endl; branch->dot(os, showZero); } @@ -671,7 +674,14 @@ namespace gtsam { int result = system( ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); -} + } + + template + std::string DecisionTree::dot(bool showZero) const { + std::stringstream ss; + dot(ss, showZero); + return ss.str(); + } /*********************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0ee0b8be0c..0a78d46352 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -19,6 +19,8 @@ #pragma once +#include + #include #include @@ -35,7 +37,7 @@ namespace gtsam { * Y = function range (any algebra), e.g., bool, int, double */ template - class DecisionTree { + class GTSAM_EXPORT DecisionTree { public: @@ -198,6 +200,9 @@ namespace gtsam { /** output to graphviz format, open a file */ void dot(const std::string& name, bool showZero = true) const; + /** output to graphviz format string */ + std::string dot(bool showZero = true) const; + /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index b7b9d70348..7aed00c57d 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -134,5 +134,52 @@ namespace gtsam { return boost::make_shared(dkeys, result); } -/* ************************************************************************* */ + /* ************************************************************************* */ + std::vector> DecisionTreeFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs; + for (auto& key : keys()) { + pairs.emplace_back(key, cardinalities_.at(key)); + } + // Reverse to make cartesianProduct output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + + /* ************************************************************************* */ + std::string DecisionTreeFactor::markdown( + const KeyFormatter& keyFormatter) const { + std::stringstream ss; + + // Print out header and construct argument for `cartesianProduct`. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + auto rows = enumerate(); + for (const auto& kv : rows) { + ss << "|"; + auto assignment = kv.first; + for (auto& key : keys()) ss << assignment.at(key) << "|"; + ss << kv.second << "|\n"; + } + return ss.str(); + } + + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index d1696a2818..f90af56dd0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -61,6 +61,15 @@ namespace gtsam { DiscreteFactor(keys.indices()), Potentials(keys, table) { } + /// Single-key specialization + template + DecisionTreeFactor(const DiscreteKey& key, SOURCE table) + : DecisionTreeFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) + : DecisionTreeFactor(DiscreteKeys{key}, row) {} + /** Construct from a DiscreteConditional type */ DecisionTreeFactor(const DiscreteConditional& c); @@ -80,7 +89,7 @@ namespace gtsam { /// @{ /// Value is just look up in AlgebraicDecisonTree - double operator()(const Values& values) const override { + double operator()(const DiscreteValues& values) const override { return Potentials::operator()(values); } @@ -162,7 +171,19 @@ namespace gtsam { // Potentials::reduceWithInverse(inverseReduction); // } + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + + /// @} + }; // DecisionTreeFactor diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 84a80c5651..d9fba630e5 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -35,41 +35,41 @@ namespace gtsam { } /* ************************************************************************* */ -// void DiscreteBayesNet::add_front(const Signature& s) { -// push_front(boost::make_shared(s)); -// } - - /* ************************************************************************* */ - void DiscreteBayesNet::add(const Signature& s) { - push_back(boost::make_shared(s)); - } - - /* ************************************************************************* */ - double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const { + double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { // evaluate all conditionals and multiply double result = 1.0; - for(DiscreteConditional::shared_ptr conditional: *this) + for(const DiscreteConditional::shared_ptr& conditional: *this) result *= (*conditional)(values); return result; } /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const { + DiscreteValues DiscreteBayesNet::optimize() const { // solve each node in turn in topological sort order (parents first) - DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); + DiscreteValues result; for (auto conditional: boost::adaptors::reverse(*this)) - conditional->solveInPlace(*result); + conditional->solveInPlace(&result); return result; } /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteBayesNet::sample() const { + DiscreteValues DiscreteBayesNet::sample() const { // sample each node in turn in topological sort order (parents first) - DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); + DiscreteValues result; for (auto conditional: boost::adaptors::reverse(*this)) - conditional->sampleInPlace(*result); + conditional->sampleInPlace(&result); return result; } + /* ************************************************************************* */ + std::string DiscreteBayesNet::markdown( + const KeyFormatter& keyFormatter) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesNet` of size " << size() << endl << endl; + for(const DiscreteConditional::shared_ptr& conditional: *this) + ss << conditional->markdown(keyFormatter) << endl; + return ss.str(); + } /* ************************************************************************* */ } // namespace diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index d5ba30584c..aed4cec0aa 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -13,6 +13,7 @@ * @file DiscreteBayesNet.h * @date Feb 15, 2011 * @author Duy-Nguyen Ta + * @author Frank dellaert */ #pragma once @@ -22,6 +23,7 @@ #include #include #include +#include #include namespace gtsam { @@ -71,24 +73,45 @@ namespace gtsam { /// @name Standard Interface /// @{ - /** Add a DiscreteCondtional */ - void add(const Signature& s); + // Add inherited versions of add. + using Base::add; + + /** Add a DiscretePrior using a table or a string */ + void add(const DiscreteKey& key, const std::string& spec) { + emplace_shared(key, spec); + } -// /** Add a DiscreteCondtional in front, when listing parents first*/ -// GTSAM_EXPORT void add_front(const Signature& s); + /** Add a DiscreteCondtional */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + //** evaluate for given DiscreteValues */ + double evaluate(const DiscreteValues & values) const; - //** evaluate for given Values */ - double evaluate(const DiscreteConditional::Values & values) const; + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } /** * Solve the DiscreteBayesNet by back-substitution */ - DiscreteFactor::sharedValues optimize() const; + DiscreteValues optimize() const; /** Do ancestral sampling */ - DiscreteFactor::sharedValues sample() const; + DiscreteValues sample() const; ///@} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} private: /** Serialization function */ diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 990d10dbe6..8a9186d05a 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -31,7 +31,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteBayesTreeClique::evaluate( - const DiscreteConditional::Values& values) const { + const DiscreteValues& values) const { // evaluate all conditionals and multiply double result = (*conditional_)(values); for (const auto& child : children) { @@ -47,7 +47,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteBayesTree::evaluate( - const DiscreteConditional::Values& values) const { + const DiscreteValues& values) const { double result = 1.0; for (const auto& root : roots_) { result *= root->evaluate(values); @@ -55,8 +55,21 @@ namespace gtsam { return result; } -} // \namespace gtsam - - - + /* **************************************************************************/ + std::string DiscreteBayesTree::markdown( + const KeyFormatter& keyFormatter) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; + auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, + size_t& indent) { + ss << "\n" << clique->conditional()->markdown(keyFormatter); + return indent + 1; + }; + size_t indent; + treeTraversal::DepthFirstForest(*this, indent, visitor); + return ss.str(); + } + /* **************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 29da5817e2..12d6017cc3 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique conditional_->printSignature(s, formatter); } - //** evaluate conditional probability of subtree for given Values */ - double evaluate(const DiscreteConditional::Values& values) const; + //** evaluate conditional probability of subtree for given DiscreteValues */ + double evaluate(const DiscreteValues& values) const; }; /* ************************************************************************* */ @@ -72,14 +72,31 @@ class GTSAM_EXPORT DiscreteBayesTree typedef DiscreteBayesTree This; typedef boost::shared_ptr shared_ptr; + /// @name Standard interface + /// @{ /** Default constructor, creates an empty Bayes tree */ DiscreteBayesTree() {} /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; - //** evaluate probability for given Values */ - double evaluate(const DiscreteConditional::Values& values) const; + //** evaluate probability for given DiscreteValues */ + double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } + + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index e3d187303c..46d5509e06 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -97,32 +97,93 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const { - ADT pFS(*this); - Key j; size_t value; - for(Key key: parents()) { +static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, + const DiscreteValues& parentsValues) { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + DiscreteConditional::ADT adt(conditional); + size_t value; + for (Key j : conditional.parents()) { try { - j = (key); value = parentsValues.at(j); - pFS = pFS.choose(j, value); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (std::out_of_range&) { + parentsValues.print("parentsValues: "); + throw runtime_error("DiscreteConditional::choose: parent value missing"); + }; + } + return adt; +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::choose( + const DiscreteValues& parentsValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + ADT adt(*this); + size_t value; + for (Key j : parents()) { + try { + value = parentsValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. } catch (exception&) { - cout << "Key: " << j << " Value: " << value << endl; parentsValues.print("parentsValues: "); - // pFS.print("pFS: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; } - return pFS; + // Convert ADT to factor. + DiscreteKeys discreteKeys; + for (Key j : frontals()) { + discreteKeys.emplace_back(j, this->cardinality(j)); + } + return boost::make_shared(discreteKeys, adt); +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + const DiscreteValues& frontalValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the frontal variables. + ADT adt(*this); + size_t value; + for (Key j : frontals()) { + try { + value = frontalValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + frontalValues.print("frontalValues: "); + throw runtime_error("DiscreteConditional::choose: frontal value missing"); + }; + } + + // Convert ADT to factor. + DiscreteKeys discreteKeys; + for (Key j : parents()) { + discreteKeys.emplace_back(j, this->cardinality(j)); + } + return boost::make_shared(discreteKeys, adt); } /* ******************************************************************************** */ -void DiscreteConditional::solveInPlace(Values& values) const { +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + size_t parent_value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value likelihood can only be invoked on single-variable " + "conditional"); + DiscreteValues values; + values.emplace(keys_[0], parent_value); + return likelihood(values); +} + +/* ******************************************************************************** */ +void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // TODO: Abhijit asks: is this really the fastest way? He thinks it is. - ADT pFS = choose(values); // P(F|S=parentsValues) + ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) // Initialize - Values mpe; + DiscreteValues mpe; double maxP = 0; DiscreteKeys keys; @@ -131,10 +192,10 @@ void DiscreteConditional::solveInPlace(Values& values) const { keys & dk; } // Get all Possible Configurations - vector allPosbValues = cartesianProduct(keys); + const auto allPosbValues = cartesianProduct(keys); // Find the MPE - for(Values& frontalVals: allPosbValues) { + for(const auto& frontalVals: allPosbValues) { double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) // Update MPE solution if better if (pValueS > maxP) { @@ -145,28 +206,28 @@ void DiscreteConditional::solveInPlace(Values& values) const { //set values (inPlace) to mpe for(Key j: frontals()) { - values[j] = mpe[j]; + (*values)[j] = mpe[j]; } } /* ******************************************************************************** */ -void DiscreteConditional::sampleInPlace(Values& values) const { +void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { assert(nrFrontals() == 1); Key j = (firstFrontalKey()); - size_t sampled = sample(values); // Sample variable - values[j] = sampled; // store result in partial solution + size_t sampled = sample(*values); // Sample variable given parents + (*values)[j] = sampled; // store result in partial solution } /* ******************************************************************************** */ -size_t DiscreteConditional::solve(const Values& parentsValues) const { +size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { // TODO: is this really the fastest way? I think it is. - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO, only works for one key now, seems horribly slow this way size_t mpe = 0; - Values frontals; + DiscreteValues frontals; double maxP = 0; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); @@ -183,18 +244,22 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const { } /* ******************************************************************************** */ -size_t DiscreteConditional::sample(const Values& parentsValues) const { +size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way - assert(nrFrontals() == 1); + if (nrFrontals() != 1) { + throw std::invalid_argument( + "DiscreteConditional::sample can only be called on single variable " + "conditionals"); + } Key key = firstFrontalKey(); size_t nj = cardinality(key); vector p(nj); - Values frontals; + DiscreteValues frontals; for (size_t value = 0; value < nj; value++) { frontals[key] = value; p[value] = pFS(frontals); // P(F=value|S=parentsValues) @@ -207,5 +272,91 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const { } /* ******************************************************************************** */ +size_t DiscreteConditional::sample(size_t parent_value) const { + if (nrParents() != 1) + throw std::invalid_argument( + "Single value sample() can only be invoked on single-parent " + "conditional"); + DiscreteValues values; + values.emplace(keys_.back(), parent_value); + return sample(values); +} + +/* ************************************************************************* */ +std::string DiscreteConditional::markdown( + const KeyFormatter& keyFormatter) const { + std::stringstream ss; + + // Print out signature. + ss << " *P("; + bool first = true; + for (Key key : frontals()) { + if (!first) ss << ","; + ss << keyFormatter(key); + first = false; + } + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << ")*:\n" << std::endl; + ss << DecisionTreeFactor::markdown(keyFormatter); + return ss.str(); + } + + // We have parents, continue signature and do custom print. + ss << "|"; + first = true; + for (Key parent : parents()) { + if (!first) ss << ","; + ss << keyFormatter(parent); + first = false; + } + ss << ")*:\n" << std::endl; + + // Print out header and construct argument for `cartesianProduct`. + std::vector> pairs; + ss << "|"; + const_iterator it; + for(Key parent: parents()) { + ss << keyFormatter(parent) << "|"; + pairs.emplace_back(parent, cardinalities_.at(parent)); + } + + size_t n = 1; + for(Key key: frontals()) { + size_t k = cardinalities_.at(key); + pairs.emplace_back(key, k); + n *= k; + } + std::vector> slatnorf(pairs.rbegin(), + pairs.rend() - nrParents()); + const auto frontal_assignments = cartesianProduct(slatnorf); + for (const auto& a : frontal_assignments) { + for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it); + ss << "|"; + } + ss << "\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|"; + ss << "\n"; + + // Print out all rows. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + size_t count = 0; + for (const auto& a : assignments) { + if (count == 0) { + ss << "|"; + for (it = beginParents(); it != endParents(); ++it) + ss << a.at(*it) << "|"; + } + ss << operator()(a) << "|"; + count = (count + 1) % n; + if (count == 0) ss << "\n"; + } + return ss.str(); +} +/* ************************************************************************* */ -}// namespace +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 8299fab2cf..d21e3ae264 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -42,10 +42,7 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional BaseConditional; ///< Typedef to our conditional base class - /** A map from keys to values.. - * TODO: Again, do we need this??? */ - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Values = DiscreteValues; ///< backwards compatibility /// @name Standard Constructors /// @{ @@ -60,6 +57,34 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, /** Construct from signature */ DiscreteConditional(const Signature& signature); + /** + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const Signature::Table& table) + : DiscreteConditional(Signature(key, parents, table)) {} + + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The string is parsed into a Signature::Table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : DiscreteConditional(Signature(key, parents, spec)) {} + + /// No-parent specialization; can also use DiscretePrior. + DiscreteConditional(const DiscreteKey& key, const std::string& spec) + : DiscreteConditional(Signature(key, {}, spec)) {} + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); @@ -102,7 +127,7 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, } /// Evaluate, just look up in AlgebraicDecisonTree - double operator()(const Values& values) const override { + double operator()(const DiscreteValues& values) const override { return Potentials::operator()(values); } @@ -111,35 +136,54 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); } - /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ - ADT choose(const Assignment& parentsValues) const; + /** Restrict to given parent values, returns DecisionTreeFactor */ + DecisionTreeFactor::shared_ptr choose( + const DiscreteValues& parentsValues) const; + + /** Convert to a likelihood factor by providing value before bar. */ + DecisionTreeFactor::shared_ptr likelihood( + const DiscreteValues& frontalValues) const; + + /** Single variable version of likelihood. */ + DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; /** * solve a conditional * @param parentsValues Known values of the parents * @return MPE value of the child (1 frontal variable). */ - size_t solve(const Values& parentsValues) const; + size_t solve(const DiscreteValues& parentsValues) const; /** * sample * @param parentsValues Known values of the parents * @return sample from conditional */ - size_t sample(const Values& parentsValues) const; + size_t sample(const DiscreteValues& parentsValues) const; + + + /// Single value version. + size_t sample(size_t parent_value) const; /// @} /// @name Advanced Interface /// @{ /// solve a conditional, in place - void solveInPlace(Values& parentsValues) const; + void solveInPlace(DiscreteValues* parentsValues) const; /// sample in place, stores result in partial solution - void sampleInPlace(Values& parentsValues) const; + void sampleInPlace(DiscreteValues* parentsValues) const; /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /// @} }; // DiscreteConditional diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 6b0919507b..d7deca3838 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -18,7 +18,7 @@ #pragma once -#include +#include #include #include @@ -40,18 +40,7 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class typedef Factor Base; ///< Our base class - /** A map from keys to values - * TODO: Do we need this? Should we just use gtsam::Values? - * We just need another special DiscreteValue to represent labels, - * However, all other Lie's operators are undefined in this class. - * The good thing is we can have a Hybrid graph of discrete/continuous variables - * together.. - * Another good thing is we don't need to have the special DiscreteKey which stores - * cardinality of a Discrete variable. It should be handled naturally in - * the new class DiscreteValue, as the varible's type (domain) - */ - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Values = DiscreteValues; ///< backwards compatibility public: @@ -92,19 +81,26 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /// @{ /// Find value for given assignment of values to variables - virtual double operator()(const Values&) const = 0; + virtual double operator()(const DiscreteValues&) const = 0; /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + virtual std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0; + /// @} }; // DiscreteFactor // traits template<> struct traits : public Testable {}; -template<> struct traits : public Testable {}; }// namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index e41968d6b4..bd84e13647 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -56,7 +56,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteFactorGraph::operator()( - const DiscreteFactor::Values &values) const { + const DiscreteValues &values) const { double product = 1.0; for( const sharedFactor& factor: factors_ ) product *= (*factor)(values); @@ -94,7 +94,7 @@ namespace gtsam { // } /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const + DiscreteValues DiscreteFactorGraph::optimize() const { gttic(DiscreteFactorGraph_optimize); return BaseEliminateable::eliminateSequential()->optimize(); @@ -129,6 +129,18 @@ namespace gtsam { return std::make_pair(cond, sum); } -/* ************************************************************************* */ -} // namespace + /* ************************************************************************* */ + std::string DiscreteFactorGraph::markdown( + const KeyFormatter& keyFormatter) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; + for (size_t i = 0; i < factors_.size(); i++) { + ss << "factor " << i << ":\n"; + ss << factors_[i]->markdown(keyFormatter) << endl; + } + return ss.str(); + } + /* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index f39adc9a86..6856493f7f 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -71,10 +71,10 @@ public EliminateableFactorGraph { typedef EliminateableFactorGraph BaseEliminateable; ///< Typedef to base elimination class typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + using Values = DiscreteValues; ///< backwards compatibility + /** A map from keys to values */ typedef KeyVector Indices; - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; /** Default constructor */ DiscreteFactorGraph() {} @@ -101,35 +101,23 @@ public EliminateableFactorGraph { /// @} - template - void add(const DiscreteKey& j, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j); - push_back(boost::make_shared(keys, table)); - } - - template - void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j1); - keys.push_back(j2); - push_back(boost::make_shared(keys, table)); - } - - /** add shared discreteFactor immediately from arguments */ - template - void add(const DiscreteKeys& keys, SOURCE table) { - push_back(boost::make_shared(keys, table)); + /** Add a decision-tree factor */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); } - + /** Return the set of variables involved in the factors (set union) */ KeySet keys() const; /** return product of all factors as a single factor */ DecisionTreeFactor product() const; - /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ - double operator()(const DiscreteFactor::Values & values) const; + /** + * Evaluates the factor graph given values, returns the joint probability of + * the factor graph given specific instantiation of values + */ + double operator()(const DiscreteValues& values) const; /// print void print( @@ -140,7 +128,7 @@ public EliminateableFactorGraph { * the dense elimination function specified in \c function, * followed by back-substitution resulting from elimination. Is equivalent * to calling graph.eliminateSequential()->optimize(). */ - DiscreteFactor::sharedValues optimize() const; + DiscreteValues optimize() const; // /** Permute the variables in the factors */ @@ -149,6 +137,14 @@ public EliminateableFactorGraph { // /** Apply a reduction, which is a remapping of variable indices. */ // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; // \ DiscreteFactorGraph /// traits diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index c041c7e8ed..ae4dac38fc 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -31,19 +31,19 @@ namespace gtsam { * Key type for discrete conditionals * Includes name and cardinality */ - typedef std::pair DiscreteKey; + using DiscreteKey = std::pair; /// DiscreteKeys is a set of keys that can be assembled using the & operator - struct DiscreteKeys: public std::vector { + struct GTSAM_EXPORT DiscreteKeys: public std::vector { - /// Default constructor - DiscreteKeys() { - } + // Forward all constructors. + using std::vector::vector; + + /// Constructor for serialization + DiscreteKeys() : std::vector::vector() {} /// Construct from a key - DiscreteKeys(const DiscreteKey& key) { - push_back(key); - } + explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } /// Construct from a vector of keys DiscreteKeys(const std::vector& keys) : @@ -51,13 +51,13 @@ namespace gtsam { } /// Construct from cardinalities with default names - GTSAM_EXPORT DiscreteKeys(const std::vector& cs); + DiscreteKeys(const std::vector& cs); /// Return a vector of indices - GTSAM_EXPORT KeyVector indices() const; + KeyVector indices() const; /// Return a map from index to cardinality - GTSAM_EXPORT std::map cardinalities() const; + std::map cardinalities() const; /// Add a key (non-const!) DiscreteKeys& operator&(const DiscreteKey& key) { @@ -67,5 +67,5 @@ namespace gtsam { }; // DiscreteKeys /// Create a list from two keys - GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); + DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); } diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index c2a188e086..27352a2110 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -29,7 +29,7 @@ namespace gtsam { /** * A class for computing marginals of variables in a DiscreteFactorGraph */ - class DiscreteMarginals { +class GTSAM_EXPORT DiscreteMarginals { protected: @@ -64,7 +64,7 @@ namespace gtsam { //Create result Vector vResult(key.second); for (size_t state = 0; state < key.second ; ++ state) { - DiscreteFactor::Values values; + DiscreteValues values; values[key.first] = state; vResult(state) = (*marginalFactor)(values); } diff --git a/gtsam/discrete/DiscretePrior.cpp b/gtsam/discrete/DiscretePrior.cpp new file mode 100644 index 0000000000..3941e0199e --- /dev/null +++ b/gtsam/discrete/DiscretePrior.cpp @@ -0,0 +1,50 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscretePrior.cpp + * @date December 2021 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +void DiscretePrior::print(const std::string& s, + const KeyFormatter& formatter) const { + Base::print(s, formatter); +} + +double DiscretePrior::operator()(size_t value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value operator can only be invoked on single-variable " + "priors"); + DiscreteValues values; + values.emplace(keys_[0], value); + return Base::operator()(values); +} + +std::vector DiscretePrior::pmf() const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "DiscretePrior::pmf only defined for single-variable priors"); + const size_t nrValues = cardinalities_.at(keys_[0]); + std::vector array; + array.reserve(nrValues); + for (size_t v = 0; v < nrValues; v++) { + array.push_back(operator()(v)); + } + return array; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h new file mode 100644 index 0000000000..1a7c6ae6cb --- /dev/null +++ b/gtsam/discrete/DiscretePrior.h @@ -0,0 +1,111 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscretePrior.h + * @date December 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include + +#include + +namespace gtsam { + +/** + * A prior probability on a set of discrete variables. + * Derives from DiscreteConditional + */ +class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { + public: + using Base = DiscreteConditional; + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + DiscretePrior() {} + + /// Constructor from factor. + DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} + + /** + * Construct from a Signature. + * + * Example: DiscretePrior P(D % "3/2"); + */ + DiscretePrior(const Signature& s) : Base(s) {} + + /** + * Construct from key and a Signature::Table specifying the + * conditional probability table (CPT). + * + * Example: DiscretePrior P(D, table); + */ + DiscretePrior(const DiscreteKey& key, const Signature::Table& table) + : Base(Signature(key, {}, table)) {} + + /** + * Construct from key and a string specifying the conditional + * probability table (CPT). + * + * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); + */ + DiscretePrior(const DiscreteKey& key, const std::string& spec) + : DiscretePrior(Signature(key, {}, spec)) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Prior: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// @} + /// @name Standard interface + /// @{ + + /// Evaluate given a single value. + double operator()(size_t value) const; + + /// We also want to keep the Base version, taking DiscreteValues: + // TODO(dellaert): does not play well with wrapper! + // using Base::operator(); + + /// Return entire probability mass function. + std::vector pmf() const; + + /** + * solve a conditional + * @return MPE value of the child (1 frontal variable). + */ + size_t solve() const { return Base::solve({}); } + + /** + * sample + * @return sample from conditional + */ + size_t sample() const { return Base::sample({}); } + + /// @} +}; +// DiscretePrior + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h new file mode 100644 index 0000000000..2d9c8d3cfb --- /dev/null +++ b/gtsam/discrete/DiscreteValues.h @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteValues.h + * @date Dec 13, 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include + +namespace gtsam { + +/** A map from keys to values + * TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues? + * We just need another special DiscreteValue to represent labels, + * However, all other Lie's operators are undefined in this class. + * The good thing is we can have a Hybrid graph of discrete/continuous variables + * together.. + * Another good thing is we don't need to have the special DiscreteKey which + * stores cardinality of a Discrete variable. It should be handled naturally in + * the new class DiscreteValue, as the variable's type (domain) + */ +class DiscreteValues : public Assignment { + public: + using Assignment::Assignment; // all constructors + + // Define the implicit default constructor. + DiscreteValues() = default; + + // Construct from assignment. + DiscreteValues(const Assignment& a) : Assignment(a) {} + + void print(const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + std::cout << s << ": "; + for (const typename Assignment::value_type& keyValue : *this) + std::cout << "(" << keyFormatter(keyValue.first) << ", " + << keyValue.second << ")"; + std::cout << std::endl; + } +}; + +// traits +template<> struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp index 331a76c13a..fa491eba36 100644 --- a/gtsam/discrete/Potentials.cpp +++ b/gtsam/discrete/Potentials.cpp @@ -26,10 +26,6 @@ using namespace std; namespace gtsam { -// explicit instantiation -template class DecisionTree; -template class AlgebraicDecisionTree; - /* ************************************************************************* */ double Potentials::safe_div(const double& a, const double& b) { // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h index 1078b4c617..856b928168 100644 --- a/gtsam/discrete/Potentials.h +++ b/gtsam/discrete/Potentials.h @@ -29,7 +29,7 @@ namespace gtsam { /** * A base class for both DiscreteFactor and DiscreteConditional */ - class Potentials: public AlgebraicDecisionTree { + class GTSAM_EXPORT Potentials: public AlgebraicDecisionTree { public: @@ -46,7 +46,7 @@ namespace gtsam { } // Safe division for probabilities - GTSAM_EXPORT static double safe_div(const double& a, const double& b); + static double safe_div(const double& a, const double& b); // // Apply either a permutation or a reduction // template @@ -55,10 +55,10 @@ namespace gtsam { public: /** Default constructor for I/O */ - GTSAM_EXPORT Potentials(); + Potentials(); /** Constructor from Indices and ADT */ - GTSAM_EXPORT Potentials(const DiscreteKeys& keys, const ADT& decisionTree); + Potentials(const DiscreteKeys& keys, const ADT& decisionTree); /** Constructor from Indices and (string or doubles) */ template @@ -67,8 +67,8 @@ namespace gtsam { } // Testable - GTSAM_EXPORT bool equals(const Potentials& other, double tol = 1e-9) const; - GTSAM_EXPORT void print(const std::string& s = "Potentials: ", + bool equals(const Potentials& other, double tol = 1e-9) const; + void print(const std::string& s = "Potentials: ", const KeyFormatter& formatter = DefaultKeyFormatter) const; size_t cardinality(Key j) const { return cardinalities_.at(j);} diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 94b160a291..146555898b 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -38,19 +38,7 @@ namespace gtsam { using boost::phoenix::push_back; // Special rows, true and false - Signature::Row createF() { - Signature::Row r(2); - r[0] = 1; - r[1] = 0; - return r; - } - Signature::Row createT() { - Signature::Row r(2); - r[0] = 0; - r[1] = 1; - return r; - } - Signature::Row T = createT(), F = createF(); + Signature::Row F{1, 0}, T{0, 1}; // Special tables (inefficient, but do we care for user input?) Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { @@ -69,40 +57,13 @@ namespace gtsam { table = or_ | and_ | rows; or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; - rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42 + rows = +(row | true_ | false_); row = qi::double_ >> +("/" >> qi::double_); true_ = qi::lit("T")[qi::_val = T]; false_ = qi::lit("F")[qi::_val = F]; } } grammar; - // Create simpler parsing function to avoid the issue of only parsing a single row - bool parse_table(const string& spec, Signature::Table& table) { - // check for OR, AND on whole phrase - It f = spec.begin(), l = spec.end(); - if (qi::parse(f, l, - qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) || - qi::parse(f, l, - qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)])) - return true; - - // tokenize into separate rows - istringstream iss(spec); - string token; - while (iss >> token) { - Signature::Row values; - It tf = token.begin(), tl = token.end(); - bool r = qi::parse(tf, tl, - qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) | - qi::lit("T")[ph::ref(values) = T] | - qi::lit("F")[ph::ref(values) = F] ); - if (!r) - return false; - table.push_back(values); - } - - return true; - } } // \namespace parser ostream& operator <<(ostream &os, const Signature::Row &row) { @@ -118,6 +79,18 @@ namespace gtsam { return os; } + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table) + : key_(key), parents_(parents) { + operator=(table); + } + + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : key_(key), parents_(parents) { + operator=(spec); + } + Signature::Signature(const DiscreteKey& key) : key_(key) { } @@ -166,14 +139,11 @@ namespace gtsam { Signature& Signature::operator=(const string& spec) { spec_.reset(spec); Table table; - // NOTE: using simpler parse function to ensure boost back compatibility -// parser::It f = spec.begin(), l = spec.end(); - bool success = // -// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar - parser::parse_table(spec, table); + parser::It f = spec.begin(), l = spec.end(); + bool success = + qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); if (success) { - for(Row& row: table) - normalize(row); + for (Row& row : table) normalize(row); table_.reset(table); } return *this; diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 6c59b5bffa..ff83caa534 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -30,7 +30,7 @@ namespace gtsam { * The format is (Key % string) for nodes with no parents, * and (Key | Key, Key = string) for nodes with parents. * - * The string specifies a conditional probability spec in the 00 01 10 11 order. + * The string specifies a conditional probability table in 00 01 10 11 order. * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... * * For example, given the following keys @@ -45,9 +45,9 @@ namespace gtsam { * T|A = "99/1 95/5" * L|S = "99/1 90/10" * B|S = "70/30 40/60" - * E|T,L = "F F F 1" + * (E|T,L) = "F F F 1" * X|E = "95/5 2/98" - * D|E,B = "9/1 2/8 3/7 1/9" + * (D|E,B) = "9/1 2/8 3/7 1/9" */ class GTSAM_EXPORT Signature { @@ -72,45 +72,73 @@ namespace gtsam { boost::optional table_; public: - - /** Constructor from DiscreteKey */ - Signature(const DiscreteKey& key); - - /** the variable key */ - const DiscreteKey& key() const { - return key_; - } - - /** the parent keys */ - const DiscreteKeys& parents() const { - return parents_; - } - - /** All keys, with variable key first */ - DiscreteKeys discreteKeys() const; - - /** All key indices, with variable key first */ - KeyVector indices() const; - - // the CPT as parsed, if successful - const boost::optional
& table() const { - return table_; - } - - // the CPT as a vector of doubles, with key's values most rapidly changing - std::vector cpt() const; - - /** Add a parent */ - Signature& operator,(const DiscreteKey& parent); - - /** Add the CPT spec - Fails in boost 1.40 */ - Signature& operator=(const std::string& spec); - - /** Add the CPT spec directly as a table */ - Signature& operator=(const Table& table); - - /** provide streaming */ - GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s); + /** + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. + * + * Example: + * Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + * Signature sig(D, {E, B}, table); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table); + + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example (same CPT as above): + * Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec); + + /** + * Construct from a single DiscreteKey. + * + * The resulting signature has no parents or CPT table. Typical use then + * either adds parents with | and , operators below, or assigns a table with + * operator=(). + */ + Signature(const DiscreteKey& key); + + /** the variable key */ + const DiscreteKey& key() const { return key_; } + + /** the parent keys */ + const DiscreteKeys& parents() const { return parents_; } + + /** All keys, with variable key first */ + DiscreteKeys discreteKeys() const; + + /** All key indices, with variable key first */ + KeyVector indices() const; + + // the CPT as parsed, if successful + const boost::optional
& table() const { return table_; } + + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; + + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); + + /** Add the CPT spec */ + Signature& operator=(const std::string& spec); + + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os, + const Signature& s); }; /** @@ -122,7 +150,6 @@ namespace gtsam { /** * Helper function to create Signature objects * example: Signature s(D % "99/1"); - * Uses string parser, which requires BOOST 1.42 or higher */ GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i new file mode 100644 index 0000000000..36caccfc83 --- /dev/null +++ b/gtsam/discrete/discrete.i @@ -0,0 +1,214 @@ +//************************************************************************* +// discrete +//************************************************************************* + +namespace gtsam { + + +#include +class DiscreteKey {}; + +class DiscreteKeys { + DiscreteKeys(); + size_t size() const; + bool empty() const; + gtsam::DiscreteKey at(size_t n) const; + void push_back(const gtsam::DiscreteKey& point_pair); +}; + +// DiscreteValues is added in specializations/discrete.h as a std::map + +#include +class DiscreteFactor { + void print(string s = "DiscreteFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; + bool empty() const; + size_t size() const; + double operator()(const gtsam::DiscreteValues& values) const; +}; + +#include +virtual class DecisionTreeFactor : gtsam::DiscreteFactor { + DecisionTreeFactor(); + + DecisionTreeFactor(const gtsam::DiscreteKey& key, + const std::vector& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); + + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const std::vector& keys, string table); + + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + + void print(string s = "DecisionTreeFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + string dot(bool showZero = false) const; + std::vector> enumerate() const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +virtual class DiscreteConditional : gtsam::DecisionTreeFactor { + DiscreteConditional(); + DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal, + const gtsam::Ordering& orderedKeys); + void print(string s = "Discrete Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + void printSignature( + string s = "Discrete Conditional: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + gtsam::DecisionTreeFactor* toFactor() const; + gtsam::DecisionTreeFactor* choose( + const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* likelihood( + const gtsam::DiscreteValues& frontalValues) const; + gtsam::DecisionTreeFactor* likelihood(size_t value) const; + size_t solve(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(size_t value) const; + void solveInPlace(gtsam::DiscreteValues @parentsValues) const; + void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +virtual class DiscretePrior : gtsam::DiscreteConditional { + DiscretePrior(); + DiscretePrior(const gtsam::DecisionTreeFactor& f); + DiscretePrior(const gtsam::DiscreteKey& key, string spec); + void print(string s = "Discrete Prior\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(size_t value) const; + std::vector pmf() const; + size_t solve() const; + size_t sample() const; +}; + +#include +class DiscreteBayesNet { + DiscreteBayesNet(); + void add(const gtsam::DiscreteConditional& s); + void add(const gtsam::DiscreteKey& key, string spec); + void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, + string spec); + void add(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteConditional* at(size_t i) const; + void print(string s = "DiscreteBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + gtsam::DiscreteValues sample() const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +class DiscreteBayesTreeClique { + DiscreteBayesTreeClique(); + DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); + const gtsam::DiscreteConditional* conditional() const; + bool isRoot() const; + void printSignature( + const string& s = "Clique: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + double evaluate(const gtsam::DiscreteValues& values) const; +}; + +class DiscreteBayesTree { + DiscreteBayesTree(); + void print(string s = "DiscreteBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + + size_t size() const; + bool empty() const; + const DiscreteBayesTreeClique* operator[](size_t j) const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(const gtsam::DiscreteValues& values) const; + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +class DotWriter { + DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, + bool plotFactorPoints = true, bool connectKeysToFactor = true, + bool binaryEdges = true); +}; + +#include +class DiscreteFactorGraph { + DiscreteFactorGraph(); + DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); + + void add(const gtsam::DiscreteKey& j, string table); + void add(const gtsam::DiscreteKey& j, const std::vector& spec); + + void add(const gtsam::DiscreteKeys& keys, string table); + void add(const std::vector& keys, string table); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteFactor* at(size_t i) const; + + void print(string s = "") const; + bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const; + + gtsam::DecisionTreeFactor product() const; + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + + gtsam::DiscreteBayesNet eliminateSequential(); + gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree eliminateMultifrontal(); + gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +} // namespace gtsam diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index be720dbca4..7a33810c7d 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -18,6 +18,7 @@ #include #include // make sure we have traits +#include // headers first to make sure no missing headers //#define DT_NO_PRUNING #include @@ -445,7 +446,7 @@ TEST(ADT, equality_parser) TEST(ADT, constructor) { DiscreteKey v0(0,2), v1(1,3); - Assignment x00, x01, x02, x10, x11, x12; + DiscreteValues x00, x01, x02, x10, x11, x12; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; x02[0] = 0, x02[1] = 2; @@ -475,7 +476,7 @@ TEST(ADT, constructor) for(double& t: table) t = x++; ADT f3(z0 & z1 & z2 & z3, table); - Assignment assignment; + DiscreteValues assignment; assignment[0] = 0; assignment[1] = 0; assignment[2] = 0; @@ -501,7 +502,7 @@ TEST(ADT, conversion) // f2.print("f2"); dot(fIndexKey, "conversion-f2"); - Assignment x00, x01, x02, x10, x11, x12; + DiscreteValues x00, x01, x02, x10, x11, x12; x00[5] = 0, x00[2] = 0; x01[5] = 0, x01[2] = 1; x10[5] = 1, x10[2] = 0; @@ -577,7 +578,7 @@ TEST(ADT, zero) ADT notb(B, 1, 0); ADT anotb = a * notb; // GTSAM_PRINT(anotb); - Assignment x00, x01, x10, x11; + DiscreteValues x00, x01, x10, x11; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; x10[0] = 1, x10[1] = 0; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index dd630a284a..6af7ca7313 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -30,20 +30,18 @@ using namespace gtsam; /* ************************************************************************* */ TEST( DecisionTreeFactor, constructors) { + // Declare a bunch of keys DiscreteKey X(0,2), Y(1,3), Z(2,2); - DecisionTreeFactor f1(X, "2 8"); + // Create factors + DecisionTreeFactor f1(X, {2, 8}); DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); EXPECT_LONGS_EQUAL(1,f1.size()); EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(3,f3.size()); - // f1.print("f1:"); - // f2.print("f2:"); - // f3.print("f3:"); - - DecisionTreeFactor::Values values; + DiscreteValues values; values[0] = 1; // x values[1] = 2; // y values[2] = 1; // z @@ -55,37 +53,26 @@ TEST( DecisionTreeFactor, constructors) /* ************************************************************************* */ TEST_UNSAFE( DecisionTreeFactor, multiplication) { - // Declare a bunch of keys DiscreteKey v0(0,2), v1(1,2), v2(2,2); - // Create a factor DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); -// f1.print("f1:"); -// f2.print("f2:"); DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); DecisionTreeFactor actual = f1 * f2; -// actual.print("actual: "); CHECK(assert_equal(expected, actual)); } /* ************************************************************************* */ TEST( DecisionTreeFactor, sum_max) { - // Declare a bunch of keys DiscreteKey v0(0,3), v1(1,2); - - // Create a factor DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor expected(v1, "9 12"); DecisionTreeFactor::shared_ptr actual = f1.sum(1); CHECK(assert_equal(expected, *actual, 1e-5)); -// f1.print("f1:"); -// actual->print("actual: "); -// actual->printCache("actual cache: "); DecisionTreeFactor expected2(v1, "5 6"); DecisionTreeFactor::shared_ptr actual2 = f1.max(1); @@ -93,9 +80,43 @@ TEST( DecisionTreeFactor, sum_max) DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); -// f2.print("f2: "); -// actual22->print("actual22: "); +} +/* ************************************************************************* */ +// Check enumerate yields the correct list of assignment/value pairs. +TEST(DecisionTreeFactor, enumerate) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + auto actual = f.enumerate(); + std::vector> expected; + DiscreteValues values; + for (size_t a : {0, 1, 2}) { + for (size_t b : {0, 1}) { + values[12] = a; + values[5] = b; + expected.emplace_back(values, f(values)); + } + } + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DecisionTreeFactor, markdown) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|1|\n" + "|0|1|2|\n" + "|1|0|3|\n" + "|1|1|4|\n" + "|2|0|5|\n" + "|2|1|6|\n"; + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + string actual = f.markdown(formatter); + EXPECT(actual == expected); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 2b440e5a0c..1de45905a6 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -38,6 +38,9 @@ using namespace boost::assign; using namespace std; using namespace gtsam; +static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), + LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); + /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { DiscreteBayesNet bayesNet; @@ -71,11 +74,9 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { DiscreteBayesNet asia; - DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), - Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); - asia.add(Asia % "99/1"); - asia.add(Smoking % "50/50"); + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version asia.add(Tuberculosis | Asia = "99/1 95/5"); asia.add(LungCancer | Smoking = "99/1 90/10"); @@ -104,12 +105,12 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(expected2, *chordal->back())); // solve - DiscreteFactor::sharedValues actualMPE = chordal->optimize(); - DiscreteFactor::Values expectedMPE; + auto actualMPE = chordal->optimize(); + DiscreteValues expectedMPE; insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( LungCancer.first, 0)(Bronchitis.first, 0); - EXPECT(assert_equal(expectedMPE, *actualMPE)); + EXPECT(assert_equal(expectedMPE, actualMPE)); // add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1"); @@ -117,25 +118,25 @@ TEST(DiscreteBayesNet, Asia) { // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize(); - DiscreteFactor::Values expectedMPE2; + auto actualMPE2 = chordal2->optimize(); + DiscreteValues expectedMPE2; insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( LungCancer.first, 0)(Bronchitis.first, 1); - EXPECT(assert_equal(expectedMPE2, *actualMPE2)); + EXPECT(assert_equal(expectedMPE2, actualMPE2)); // now sample from it - DiscreteFactor::Values expectedSample; + DiscreteValues expectedSample; SETDEBUG("DiscreteConditional::sample", false); insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)( Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)( LungCancer.first, 1)(Bronchitis.first, 0); - DiscreteFactor::sharedValues actualSample = chordal2->sample(); - EXPECT(assert_equal(expectedSample, *actualSample)); + auto actualSample = chordal2->sample(); + EXPECT(assert_equal(expectedSample, actualSample)); } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteBayesNet, Sugar) { +TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); DiscreteBayesNet bn; @@ -149,6 +150,52 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) { bn.add(C | S = "1/1/2 5/2/3"); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, Dot) { + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking % "50/50"); + + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + string actual = fragment.dot(); + EXPECT(actual == + "digraph G{\n" + "0->3\n" + "4->6\n" + "3->5\n" + "6->5\n" + "}"); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteBayesNet, markdown) { + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking | Asia = "8/2 7/3"); + + string expected = + "`DiscreteBayesNet` of size 2\n" + "\n" + " *P(Asia)*:\n\n" + "|Asia|value|\n" + "|:-:|:-:|\n" + "|0|0.99|\n" + "|1|0.01|\n" + "\n" + " *P(Smoking|Asia)*:\n\n" + "|Asia|0|1|\n" + "|:-:|:-:|:-:|\n" + "|0|0.8|0.2|\n" + "|1|0.7|0.3|\n\n"; + auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; + string actual = fragment.markdown(formatter); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index ecf485036a..edb5ea46c6 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -26,88 +26,101 @@ using namespace boost::assign; #include +#include #include using namespace std; using namespace gtsam; - -static bool debug = false; +static constexpr bool debug = false; /* ************************************************************************* */ +struct TestFixture { + vector keys; + DiscreteBayesNet bayesNet; + boost::shared_ptr bayesTree; + + /** + * Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student), + * and then create the Bayes tree from it. + */ + TestFixture() { + // Define variables. + for (int i = 0; i < 15; i++) { + DiscreteKey key_i(i, 2); + keys.push_back(key_i); + } -TEST_UNSAFE(DiscreteBayesTree, ThinTree) { - const int nrNodes = 15; - const size_t nrStates = 2; + // Create thin-tree Bayesnet. + bayesNet.add(keys[14] % "1/3"); - // define variables - vector key; - for (int i = 0; i < nrNodes; i++) { - DiscreteKey key_i(i, nrStates); - key.push_back(key_i); - } + bayesNet.add(keys[13] | keys[14] = "1/3 3/1"); + bayesNet.add(keys[12] | keys[14] = "3/1 3/1"); - // create a thin-tree Bayesnet, a la Jean-Guillaume - DiscreteBayesNet bayesNet; - bayesNet.add(key[14] % "1/3"); + bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4"); + bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1"); - bayesNet.add(key[13] | key[14] = "1/3 3/1"); - bayesNet.add(key[12] | key[14] = "3/1 3/1"); + bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1"); - bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); - bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); + bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1"); + bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1"); - bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); + // Create a BayesTree out of the Bayes net. + bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); + } +}; - bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); - bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); +/* ************************************************************************* */ +TEST(DiscreteBayesTree, ThinTree) { + const TestFixture self; + const auto& keys = self.keys; if (debug) { - GTSAM_PRINT(bayesNet); - bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); + GTSAM_PRINT(self.bayesNet); + self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); } // create a BayesTree out of a Bayes net - auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); if (debug) { - GTSAM_PRINT(*bayesTree); - bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); + GTSAM_PRINT(*self.bayesTree); + self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); } // Check frontals and parents for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { - auto clique_i = (*bayesTree)[i]; + auto clique_i = (*self.bayesTree)[i]; EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); } - auto R = bayesTree->roots().front(); + auto R = self.bayesTree->roots().front(); // Check whether BN and BT give the same answer on all configurations - vector allPosbValues = cartesianProduct( - key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & - key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); + auto allPosbValues = + cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & + keys[5] & keys[6] & keys[7] & keys[8] & keys[9] & + keys[10] & keys[11] & keys[12] & keys[13] & keys[14]); for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; - double expected = bayesNet.evaluate(x); - double actual = bayesTree->evaluate(x); + DiscreteValues x = allPosbValues[i]; + double expected = self.bayesNet.evaluate(x); + double actual = self.bayesTree->evaluate(x); DOUBLES_EQUAL(expected, actual, 1e-9); } - // Calculate all some marginals for Values==all1 + // Calculate all some marginals for DiscreteValues==all1 Vector marginals = Vector::Zero(15); double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; - double px = bayesTree->evaluate(x); + DiscreteValues x = allPosbValues[i]; + double px = self.bayesTree->evaluate(x); for (size_t i = 0; i < 15; i++) if (x[i]) marginals[i] += px; if (x[12] && x[14]) { @@ -138,49 +151,49 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { } } } - DiscreteFactor::Values all1 = allPosbValues.back(); + DiscreteValues all1 = allPosbValues.back(); // check separator marginal P(S0) - auto clique = (*bayesTree)[0]; + auto clique = (*self.bayesTree)[0]; DiscreteFactorGraph separatorMarginal0 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // check separator marginal P(S9), should be P(14) - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // check separator marginal of root, should be empty - clique = (*bayesTree)[11]; + clique = (*self.bayesTree)[11]; DiscreteFactorGraph separatorMarginal11 = clique->separatorMarginal(EliminateDiscrete); LONGS_EQUAL(0, separatorMarginal11.size()); // check shortcut P(S9||R) to root - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); LONGS_EQUAL(1, shortcut.size()); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S8||R) to root - clique = (*bayesTree)[8]; + clique = (*self.bayesTree)[8]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S2||R) to root - clique = (*bayesTree)[2]; + clique = (*self.bayesTree)[2]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S0||R) to root - clique = (*bayesTree)[0]; + clique = (*self.bayesTree)[0]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); for (auto clique : cliques) { DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); if (debug) { @@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { // Check all marginals DiscreteFactor::shared_ptr marginalFactor; for (size_t i = 0; i < 15; i++) { - marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); + marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); double actual = (*marginalFactor)(all1); DOUBLES_EQUAL(marginals[i], actual, 1e-9); } @@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteBayesNet::shared_ptr actualJoint; // Check joint P(8, 2) - actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete); DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); // Check joint P(1, 2) - actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete); DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); // Check joint P(2, 4) - actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete); DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 5) - actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete); DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 6) - actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete); DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 11) - actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete); DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, Dot) { + const TestFixture self; + string actual = self.bayesTree->dot(); + EXPECT(actual == + "digraph G{\n" + "0[label=\"13,11,6,7\"];\n" + "0->1\n" + "1[label=\"14 : 11,13\"];\n" + "1->2\n" + "2[label=\"9,12 : 14\"];\n" + "2->3\n" + "3[label=\"3 : 9,12\"];\n" + "2->4\n" + "4[label=\"2 : 9,12\"];\n" + "2->5\n" + "5[label=\"8 : 12,14\"];\n" + "5->6\n" + "6[label=\"1 : 8,12\"];\n" + "5->7\n" + "7[label=\"0 : 8,12\"];\n" + "1->8\n" + "8[label=\"10 : 13,14\"];\n" + "8->9\n" + "9[label=\"5 : 10,13\"];\n" + "8->10\n" + "10[label=\"4 : 10,13\"];\n" + "}"); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3ac3ffc9eb..00ae1acd01 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -10,10 +10,11 @@ * -------------------------------------------------------------------------- */ /* - * @file testDecisionTreeFactor.cpp + * @file testDiscreteConditional.cpp * @brief unit tests for DiscreteConditional * @author Duy-Nguyen Ta - * @date Feb 14, 2011 + * @author Frank dellaert + * @date Feb 14, 2011 */ #include @@ -24,29 +25,27 @@ using namespace boost::assign; #include #include #include +#include using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST( DiscreteConditional, constructors) -{ - DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! - - DiscreteConditional::shared_ptr expected1 = // - boost::make_shared(X | Y = "1/1 2/3 1/4"); - EXPECT(expected1); - EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); - EXPECT(expected1->endParents() == expected1->end()); - EXPECT(expected1->endFrontals() == expected1->beginParents()); - +TEST(DiscreteConditional, constructors) { + DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! + + DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); + EXPECT(expected.endParents() == expected.end()); + EXPECT(expected.endFrontals() == expected.beginParents()); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(*expected1, actual1, 1e-9)); + EXPECT(assert_equal(expected, actual1, 1e-9)); - DecisionTreeFactor f2(X & Y & Z, - "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); + DecisionTreeFactor f2( + X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); } @@ -61,11 +60,10 @@ TEST(DiscreteConditional, constructors_alt_interface) { r2 += 2.0, 3.0; r3 += 1.0, 4.0; table += r1, r2, r3; - auto actual1 = boost::make_shared(X | Y = table); - EXPECT(actual1); + DiscreteConditional actual1(X, {Y}, table); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); - EXPECT(assert_equal(expected1, *actual1, 1e-9)); + EXPECT(assert_equal(expected1, actual1, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); @@ -102,9 +100,79 @@ TEST(DiscreteConditional, Combine) { c.push_back(boost::make_shared(A | B = "1/2 2/1")); c.push_back(boost::make_shared(B % "1/2")); DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional actual(2, factor); - auto expected = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(*expected, actual, 1e-5)); + DiscreteConditional expected(2, factor); + auto actual = DiscreteConditional::Combine(c.begin(), c.end()); + EXPECT(assert_equal(expected, *actual, 1e-5)); +} + +/* ************************************************************************* */ +TEST(DiscreteConditional, likelihood) { + DiscreteKey X(0, 2), Y(1, 3); + DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); + + auto actual0 = conditional.likelihood(0); + DecisionTreeFactor expected0(Y, "0.2 0.4 0.5"); + EXPECT(assert_equal(expected0, *actual0, 1e-9)); + + auto actual1 = conditional.likelihood(1); + DecisionTreeFactor expected1(Y, "0.8 0.6 0.5"); + EXPECT(assert_equal(expected1, *actual1, 1e-9)); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, no parents. +TEST(DiscreteConditional, markdown_prior) { + DiscreteKey A(Symbol('x', 1), 3); + DiscreteConditional conditional(A % "1/2/2"); + string expected = + " *P(x1)*:\n\n" + "|x1|value|\n" + "|:-:|:-:|\n" + "|0|0.2|\n" + "|1|0.4|\n" + "|2|0.4|\n"; + string actual = conditional.markdown(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, multivalued. +TEST(DiscreteConditional, markdown_multivalued) { + DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5); + DiscreteConditional conditional( + A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); + string expected = + " *P(a1|b1)*:\n\n" + "|b1|0|1|2|\n" + "|:-:|:-:|:-:|:-:|\n" + "|0|0.02|0.88|0.1|\n" + "|1|0.02|0.2|0.78|\n" + "|2|0.33|0.33|0.34|\n" + "|3|0.33|0.33|0.34|\n" + "|4|0.95|0.02|0.03|\n"; + string actual = conditional.markdown(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, two parents. +TEST(DiscreteConditional, markdown) { + DiscreteKey A(2, 2), B(1, 2), C(0, 3); + DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); + string expected = + " *P(A|B,C)*:\n\n" + "|B|C|0|1|\n" + "|:-:|:-:|:-:|:-:|\n" + "|0|0|0|1|\n" + "|0|1|0.25|0.75|\n" + "|0|2|0.5|0.5|\n" + "|1|0|0.75|0.25|\n" + "|1|1|0|1|\n" + "|1|2|1|0|\n"; + vector names{"C", "B", "A"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = conditional.markdown(formatter); + EXPECT(actual == expected); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 1defd5acff..b6172382a6 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -81,8 +81,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { graph.add(P2, "0.9 0.6"); graph.add(P1 & P2, "4 1 10 4"); - // Instantiate Values - DiscreteFactor::Values values; + // Instantiate DiscreteValues + DiscreteValues values; values[0] = 1; values[1] = 1; @@ -167,10 +167,10 @@ TEST( DiscreteFactorGraph, test) // EXPECT(assert_equal(expected, *actual2)); // Test optimization - DiscreteFactor::Values expectedValues; + DiscreteValues expectedValues; insert(expectedValues)(0, 0)(1, 0)(2, 0); - DiscreteFactor::sharedValues actualValues = graph.optimize(); - EXPECT(assert_equal(expectedValues, *actualValues)); + auto actualValues = graph.optimize(); + EXPECT(assert_equal(expectedValues, actualValues)); } /* ************************************************************************* */ @@ -186,11 +186,11 @@ TEST( DiscreteFactorGraph, testMPE) // graph.product().print(); // DiscreteSequentialSolver(graph).eliminate()->print(); - DiscreteFactor::sharedValues actualMPE = graph.optimize(); + auto actualMPE = graph.optimize(); - DiscreteFactor::Values expectedMPE; + DiscreteValues expectedMPE; insert(expectedMPE)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(expectedMPE, *actualMPE)); + EXPECT(assert_equal(expectedMPE, actualMPE)); } /* ************************************************************************* */ @@ -211,13 +211,13 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) // graph.product().potentials().dot("Darwiche-product"); // DiscreteSequentialSolver(graph).eliminate()->print(); - DiscreteFactor::Values expectedMPE; + DiscreteValues expectedMPE; insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); // Use the solver machinery. DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues actualMPE = chordal->optimize(); - EXPECT(assert_equal(expectedMPE, *actualMPE)); + auto actualMPE = chordal->optimize(); + EXPECT(assert_equal(expectedMPE, actualMPE)); // DiscreteConditional::shared_ptr root = chordal->back(); // EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9); @@ -244,8 +244,8 @@ ETree::shared_ptr eTree = ETree::Create(graph, structure); // eliminate normally and check solution DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete); // bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<"); -DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet); -EXPECT(assert_equal(expectedMPE, *actualMPE)); +auto actualMPE = optimize(*bayesNet); +EXPECT(assert_equal(expectedMPE, actualMPE)); // Approximate and check solution // DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate(); @@ -359,6 +359,67 @@ cout << unicorns; } #endif +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, Dot) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string actual = graph.dot(); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"0\"];\n" + " var1[label=\"1\"];\n" + " var2[label=\"2\"];\n" + "\n" + " var0--var1;\n" + " var0--var2;\n" + "}\n"; + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteFactorGraph, markdown) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string expected = + "`DiscreteFactorGraph` of size 2\n" + "\n" + "factor 0:\n" + "|C|A|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.2|\n" + "|0|1|0.8|\n" + "|1|0|0.3|\n" + "|1|1|0.7|\n" + "\n" + "factor 1:\n" + "|C|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.1|\n" + "|0|1|0.9|\n" + "|1|0|0.4|\n" + "|1|1|0.6|\n\n"; + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph.markdown(formatter); + EXPECT(actual == expected); + + // Make sure values are correctly displayed. + DiscreteValues values; + values[0] = 1; + values[1] = 0; + EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); +} /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteMarginals.cpp b/gtsam/discrete/tests/testDiscreteMarginals.cpp index e1eb92af3a..e75016b683 100644 --- a/gtsam/discrete/tests/testDiscreteMarginals.cpp +++ b/gtsam/discrete/tests/testDiscreteMarginals.cpp @@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) { DiscreteMarginals marginals(graph); DiscreteFactor::shared_ptr actualC = marginals(Cathy.first); - DiscreteFactor::Values values; + DiscreteValues values; values[Cathy.first] = 0; EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6); @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) { DiscreteMarginals marginals(graph); DiscreteFactor::shared_ptr actualC = marginals(key[2].first); - DiscreteFactor::Values values; + DiscreteValues values; values[key[2].first] = 0; EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4); @@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) { graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8"); // Calculate the marginals by brute force - vector allPosbValues = + auto allPosbValues = cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]); Vector T = Z_5x1, F = Z_5x1; for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; + DiscreteValues x = allPosbValues[i]; double px = graph(x); for (size_t j = 0; j < 5; j++) if (x[j]) diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp new file mode 100644 index 0000000000..b91926cc05 --- /dev/null +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -0,0 +1,55 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testDiscretePrior.cpp + * @brief unit tests for DiscretePrior + * @author Frank dellaert + * @date December 2021 + */ + +#include +#include +#include + +using namespace std; +using namespace gtsam; + +static const DiscreteKey X(0, 2); + +/* ************************************************************************* */ +TEST(DiscretePrior, constructors) { + DiscretePrior actual(X % "2/3"); + DecisionTreeFactor f(X, "0.4 0.6"); + DiscretePrior expected(f); + EXPECT(assert_equal(expected, actual, 1e-9)); +} + +/* ************************************************************************* */ +TEST(DiscretePrior, operator) { + DiscretePrior prior(X % "2/3"); + EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); + EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscretePrior, to_vector) { + DiscretePrior prior(X % "2/3"); + vector expected {0.4, 0.6}; + EXPECT(prior.pmf() == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index 049c455f72..737bd8aef0 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ TEST(testSignature, simple_conditional) { - Signature sig(X | Y = "1/1 2/3 1/4"); + Signature sig(X, {Y}, "1/1 2/3 1/4"); + CHECK(sig.table()); Signature::Table table = *sig.table(); vector row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; + LONGS_EQUAL(3, table.size()); CHECK(row[0] == table[0]); CHECK(row[1] == table[1]); CHECK(row[2] == table[2]); - DiscreteKey actKey = sig.key(); - LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + CHECK(sig.key() == X); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); + + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); } /* ************************************************************************* */ @@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) { table += row1, row2, row3; Signature sig(X | Y = table); - DiscreteKey actKey = sig.key(); - EXPECT_LONGS_EQUAL(X.first, actKey.first); + CHECK(sig.key() == X); + + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); + + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); +} - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); +/* ************************************************************************* */ +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2); + +// Make sure we can create all signatures for Asia network with constructor. +TEST(testSignature, all_examples) { + DiscreteKey X(6, 2); + Signature a(A, {}, "99/1"); + Signature s(S, {}, "50/50"); + Signature t(T, {A}, "99/1 95/5"); + Signature l(L, {S}, "99/1 90/10"); + Signature b(B, {S}, "70/30 40/60"); + Signature e(E, {T, L}, "F F F 1"); + Signature x(X, {E}, "95/5 2/98"); +} + +// Make sure we can create all signatures for Asia network with operator magic. +TEST(testSignature, all_examples_magic) { + DiscreteKey X(6, 2); + Signature a(A % "99/1"); + Signature s(S % "50/50"); + Signature t(T | A = "99/1 95/5"); + Signature l(L | S = "99/1 90/10"); + Signature b(B | S = "70/30 40/60"); + Signature e((E | T, L) = "F F F 1"); + Signature x(X | E = "95/5 2/98"); +} - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); +// Check example from docs. +TEST(testSignature, doxygen_example) { + Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + Signature d1(D, {E, B}, table); + Signature d2((D | E, B) = "9/1 2/8 3/7 1/9"); + Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9"); + EXPECT(*(d1.table()) == table); + EXPECT(*(d2.table()) == table); + EXPECT(*(d3.table()) == table); } /* ************************************************************************* */ diff --git a/gtsam/geometry/PinholeCamera.h b/gtsam/geometry/PinholeCamera.h index c1f0b6b3fe..61e9f09098 100644 --- a/gtsam/geometry/PinholeCamera.h +++ b/gtsam/geometry/PinholeCamera.h @@ -312,6 +312,16 @@ class GTSAM_EXPORT PinholeCamera: public PinholeBaseK { return range(camera.pose(), Dcamera, Dother); } + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + return K_.K() * PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4); + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_.fx());; + } + private: /** Serialization function */ diff --git a/gtsam/geometry/PinholePose.h b/gtsam/geometry/PinholePose.h index 7a0b08227c..b4999af7c8 100644 --- a/gtsam/geometry/PinholePose.h +++ b/gtsam/geometry/PinholePose.h @@ -121,6 +121,13 @@ class GTSAM_EXPORT PinholeBaseK: public PinholeBase { return _project(pw, Dpose, Dpoint, Dcal); } + /// project a 3D point from world coordinates into the image + Point2 reprojectionError(const Point3& pw, const Point2& measured, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none, + OptionalJacobian<2, DimK> Dcal = boost::none) const { + return Point2(_project(pw, Dpose, Dpoint, Dcal) - measured); + } + /// project a point at infinity from world coordinates into the image Point2 project(const Unit3& pw, OptionalJacobian<2, 6> Dpose = boost::none, OptionalJacobian<2, 2> Dpoint = boost::none, @@ -159,7 +166,6 @@ class GTSAM_EXPORT PinholeBaseK: public PinholeBase { return result; } - /// backproject a 2-dimensional point to a 3-dimensional point at infinity Unit3 backprojectPointAtInfinity(const Point2& p) const { const Point2 pn = calibration().calibrate(p); @@ -410,6 +416,16 @@ class PinholePose: public PinholeBaseK { return PinholePose(); // assumes that the default constructor is valid } + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + Matrix34 P = Matrix34(PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4)); + return K_->K() * P; + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_->fx());; + } /// @} private: diff --git a/gtsam/geometry/Quaternion.h b/gtsam/geometry/Quaternion.h index 1557a09dbd..2ef47d58e3 100644 --- a/gtsam/geometry/Quaternion.h +++ b/gtsam/geometry/Quaternion.h @@ -117,13 +117,23 @@ struct traits { omega = (-8. / 3. - 2. / 3. * qw) * q.vec(); } else { // Normal, away from zero case - _Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw); - // Important: convert to [-pi,pi] to keep error continuous - if (angle > M_PI) - angle -= twoPi; - else if (angle < -M_PI) - angle += twoPi; - omega = (angle / s) * q.vec(); + if (qw > 0) { + _Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw); + // Important: convert to [-pi,pi] to keep error continuous + if (angle > M_PI) + angle -= twoPi; + else if (angle < -M_PI) + angle += twoPi; + omega = (angle / s) * q.vec(); + } else { + // Make sure that we are using a canonical quaternion with w > 0 + _Scalar angle = 2 * acos(-qw), s = sqrt(1 - qw * qw); + if (angle > M_PI) + angle -= twoPi; + else if (angle < -M_PI) + angle += twoPi; + omega = (angle / s) * -q.vec(); + } } if(H) *H = SO3::LogmapDerivative(omega.template cast()); diff --git a/gtsam/geometry/SphericalCamera.cpp b/gtsam/geometry/SphericalCamera.cpp new file mode 100644 index 0000000000..58a29dc092 --- /dev/null +++ b/gtsam/geometry/SphericalCamera.cpp @@ -0,0 +1,109 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#include + +using namespace std; + +namespace gtsam { + +/* ************************************************************************* */ +bool SphericalCamera::equals(const SphericalCamera& camera, double tol) const { + return pose_.equals(camera.pose(), tol); +} + +/* ************************************************************************* */ +void SphericalCamera::print(const string& s) const { pose_.print(s + ".pose"); } + +/* ************************************************************************* */ +pair SphericalCamera::projectSafe(const Point3& pw) const { + const Point3 pc = pose().transformTo(pw); + Unit3 pu = Unit3::FromPoint3(pc); + return make_pair(pu, pc.norm() > 1e-8); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project2(const Point3& pw, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 3> Dpoint) const { + Matrix36 Dtf_pose; + Matrix3 Dtf_point; // calculated by transformTo if needed + const Point3 pc = + pose().transformTo(pw, Dpose ? &Dtf_pose : 0, Dpoint ? &Dtf_point : 0); + + if (pc.norm() <= 1e-8) throw("point cannot be at the center of the camera"); + + Matrix23 Dunit; // calculated by FromPoint3 if needed + Unit3 pu = Unit3::FromPoint3(Point3(pc), Dpoint ? &Dunit : 0); + + if (Dpose) *Dpose = Dunit * Dtf_pose; // 2x3 * 3x6 = 2x6 + if (Dpoint) *Dpoint = Dunit * Dtf_point; // 2x3 * 3x3 = 2x3 + return pu; +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project2(const Unit3& pwu, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 2> Dpoint) const { + Matrix23 Dtf_rot; + Matrix2 Dtf_point; // calculated by transformTo if needed + const Unit3 pu = pose().rotation().unrotate(pwu, Dpose ? &Dtf_rot : 0, + Dpoint ? &Dtf_point : 0); + + if (Dpose) + *Dpose << Dtf_rot, Matrix::Zero(2, 3); // 2x6 (translation part is zero) + if (Dpoint) *Dpoint = Dtf_point; // 2x2 + return pu; +} + +/* ************************************************************************* */ +Point3 SphericalCamera::backproject(const Unit3& pu, const double depth) const { + return pose().transformFrom(depth * pu); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::backprojectPointAtInfinity(const Unit3& p) const { + return pose().rotation().rotate(p); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project(const Point3& point, + OptionalJacobian<2, 6> Dcamera, + OptionalJacobian<2, 3> Dpoint) const { + return project2(point, Dcamera, Dpoint); +} + +/* ************************************************************************* */ +Vector2 SphericalCamera::reprojectionError( + const Point3& point, const Unit3& measured, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 3> Dpoint) const { + // project point + if (Dpose || Dpoint) { + Matrix26 H_project_pose; + Matrix23 H_project_point; + Matrix22 H_error; + Unit3 projected = project2(point, H_project_pose, H_project_point); + Vector2 error = measured.errorVector(projected, boost::none, H_error); + if (Dpose) *Dpose = H_error * H_project_pose; + if (Dpoint) *Dpoint = H_error * H_project_point; + return error; + } else { + return measured.errorVector(project2(point, Dpose, Dpoint)); + } +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/geometry/SphericalCamera.h b/gtsam/geometry/SphericalCamera.h new file mode 100644 index 0000000000..4880423d32 --- /dev/null +++ b/gtsam/geometry/SphericalCamera.h @@ -0,0 +1,241 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace gtsam { + +/** + * Empty calibration. Only needed to play well with other cameras + * (e.g., when templating functions wrt cameras), since other cameras + * have constuctors in the form ‘camera(pose,calibration)’ + * @addtogroup geometry + * \nosubgrouping + */ +class GTSAM_EXPORT EmptyCal { + public: + enum { dimension = 0 }; + EmptyCal() {} + virtual ~EmptyCal() = default; + using shared_ptr = boost::shared_ptr; + + /// return DOF, dimensionality of tangent space + inline static size_t Dim() { return dimension; } + + void print(const std::string& s) const { + std::cout << "empty calibration: " << s << std::endl; + } + + private: + /// Serialization function + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "EmptyCal", boost::serialization::base_object(*this)); + } +}; + +/** + * A spherical camera class that has a Pose3 and measures bearing vectors. + * The camera has an ‘Empty’ calibration and the only 6 dof are the pose + * @addtogroup geometry + * \nosubgrouping + */ +class GTSAM_EXPORT SphericalCamera { + public: + enum { dimension = 6 }; + + using Measurement = Unit3; + using MeasurementVector = std::vector; + using CalibrationType = EmptyCal; + + private: + Pose3 pose_; ///< 3D pose of camera + + protected: + EmptyCal::shared_ptr emptyCal_; + + public: + /// @} + /// @name Standard Constructors + /// @{ + + /// Default constructor + SphericalCamera() + : pose_(Pose3::identity()), emptyCal_(boost::make_shared()) {} + + /// Constructor with pose + explicit SphericalCamera(const Pose3& pose) + : pose_(pose), emptyCal_(boost::make_shared()) {} + + /// Constructor with empty intrinsics (needed for smart factors) + explicit SphericalCamera(const Pose3& pose, + const EmptyCal::shared_ptr& cal) + : pose_(pose), emptyCal_(cal) {} + + /// @} + /// @name Advanced Constructors + /// @{ + explicit SphericalCamera(const Vector& v) : pose_(Pose3::Expmap(v)) {} + + /// Default destructor + virtual ~SphericalCamera() = default; + + /// return shared pointer to calibration + const EmptyCal::shared_ptr& sharedCalibration() const { + return emptyCal_; + } + + /// return calibration + const EmptyCal& calibration() const { return *emptyCal_; } + + /// @} + /// @name Testable + /// @{ + + /// assert equality up to a tolerance + bool equals(const SphericalCamera& camera, double tol = 1e-9) const; + + /// print + virtual void print(const std::string& s = "SphericalCamera") const; + + /// @} + /// @name Standard Interface + /// @{ + + /// return pose, constant version + const Pose3& pose() const { return pose_; } + + /// get rotation + const Rot3& rotation() const { return pose_.rotation(); } + + /// get translation + const Point3& translation() const { return pose_.translation(); } + + // /// return pose, with derivative + // const Pose3& getPose(OptionalJacobian<6, 6> H) const; + + /// @} + /// @name Transformations and measurement functions + /// @{ + + /// Project a point into the image and check depth + std::pair projectSafe(const Point3& pw) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D point in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project2(const Point3& pw, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D direction in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project2(const Unit3& pwu, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 2> Dpoint = boost::none) const; + + /// backproject a 2-dimensional point to a 3-dimensional point at given depth + Point3 backproject(const Unit3& p, const double depth) const; + + /// backproject point at infinity + Unit3 backprojectPointAtInfinity(const Unit3& p) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D point in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project(const Point3& point, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + + /** Compute reprojection error for a given 3D point in world coordinates + * @param point 3D point in world coordinates + * @return the tangent space error between the projection and the measurement + */ + Vector2 reprojectionError(const Point3& point, const Unit3& measured, + OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + /// @} + + /// move a cameras according to d + SphericalCamera retract(const Vector6& d) const { + return SphericalCamera(pose().retract(d)); + } + + /// return canonical coordinate + Vector6 localCoordinates(const SphericalCamera& p) const { + return pose().localCoordinates(p.pose()); + } + + /// for Canonical + static SphericalCamera identity() { + return SphericalCamera( + Pose3::identity()); // assumes that the default constructor is valid + } + + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + return Matrix34(pose_.inverse().matrix().block(0, 0, 3, 4)); + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension, 1>::Constant(0.0); + } + + /// @deprecated + size_t dim() const { return 6; } + + /// @deprecated + static size_t Dim() { return 6; } + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(pose_); + } + + public: + GTSAM_MAKE_ALIGNED_OPERATOR_NEW +}; +// end of class SphericalCamera + +template <> +struct traits : public internal::LieGroup {}; + +template <> +struct traits : public internal::LieGroup {}; + +} // namespace gtsam diff --git a/gtsam/geometry/StereoCamera.h b/gtsam/geometry/StereoCamera.h index 3b5bdaefc0..c53fc11c99 100644 --- a/gtsam/geometry/StereoCamera.h +++ b/gtsam/geometry/StereoCamera.h @@ -170,6 +170,11 @@ class GTSAM_EXPORT StereoCamera { OptionalJacobian<3, 3> H2 = boost::none, OptionalJacobian<3, 0> H3 = boost::none) const; + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_->fx());; + } + /// @} private: diff --git a/gtsam/geometry/geometry.i b/gtsam/geometry/geometry.i index a40951d3e7..0def842651 100644 --- a/gtsam/geometry/geometry.i +++ b/gtsam/geometry/geometry.i @@ -27,9 +27,6 @@ class Point2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Point2Pairs { @@ -104,9 +101,6 @@ class StereoPoint2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -131,9 +125,6 @@ class Point3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Point3Pairs { @@ -191,9 +182,6 @@ class Rot2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -372,9 +360,6 @@ class Rot3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -433,9 +418,6 @@ class Pose2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; boost::optional align(const gtsam::Point2Pairs& pairs); @@ -502,9 +484,6 @@ class Pose3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Pose3Pairs { @@ -547,9 +526,6 @@ class Unit3 { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // enabling function to compare objects bool equals(const gtsam::Unit3& expected, double tol) const; }; @@ -611,9 +587,6 @@ class Cal3_S2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -642,9 +615,6 @@ virtual class Cal3DS2_Base { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -668,9 +638,6 @@ virtual class Cal3DS2 : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -705,9 +672,6 @@ virtual class Cal3Unified : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -750,9 +714,6 @@ class Cal3Fisheye { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -811,9 +772,6 @@ class Cal3Bundler { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -847,9 +805,6 @@ class CalibratedCamera { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -889,9 +844,6 @@ class PinholeCamera { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -962,9 +914,6 @@ class StereoCamera { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include diff --git a/gtsam/geometry/tests/testSphericalCamera.cpp b/gtsam/geometry/tests/testSphericalCamera.cpp new file mode 100644 index 0000000000..4bc851f351 --- /dev/null +++ b/gtsam/geometry/tests/testSphericalCamera.cpp @@ -0,0 +1,163 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#include +#include +#include +#include + +#include +#include + +using namespace std::placeholders; +using namespace std; +using namespace gtsam; + +typedef SphericalCamera Camera; + +// static const Cal3_S2 K(625, 625, 0, 0, 0); +// +static const Pose3 pose(Rot3(Vector3(1, -1, -1).asDiagonal()), + Point3(0, 0, 0.5)); +static const Camera camera(pose); +// +static const Pose3 pose1(Rot3(), Point3(0, 1, 0.5)); +static const Camera camera1(pose1); + +static const Point3 point1(-0.08, -0.08, 0.0); +static const Point3 point2(-0.08, 0.08, 0.0); +static const Point3 point3(0.08, 0.08, 0.0); +static const Point3 point4(0.08, -0.08, 0.0); + +// manually computed in matlab +static const Unit3 bearing1(-0.156054862928174, 0.156054862928174, + 0.975342893301088); +static const Unit3 bearing2(-0.156054862928174, -0.156054862928174, + 0.975342893301088); +static const Unit3 bearing3(0.156054862928174, -0.156054862928174, + 0.975342893301088); +static const Unit3 bearing4(0.156054862928174, 0.156054862928174, + 0.975342893301088); + +static double depth = 0.512640224719052; +/* ************************************************************************* */ +TEST(SphericalCamera, constructor) { + EXPECT(assert_equal(pose, camera.pose())); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, project) { + // expected from manual calculation in Matlab + EXPECT(assert_equal(camera.project(point1), bearing1)); + EXPECT(assert_equal(camera.project(point2), bearing2)); + EXPECT(assert_equal(camera.project(point3), bearing3)); + EXPECT(assert_equal(camera.project(point4), bearing4)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, backproject) { + EXPECT(assert_equal(camera.backproject(bearing1, depth), point1)); + EXPECT(assert_equal(camera.backproject(bearing2, depth), point2)); + EXPECT(assert_equal(camera.backproject(bearing3, depth), point3)); + EXPECT(assert_equal(camera.backproject(bearing4, depth), point4)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, backproject2) { + Point3 origin(0, 0, 0); + Rot3 rot(1., 0., 0., 0., 0., 1., 0., -1., 0.); // a camera1 looking down + Camera camera(Pose3(rot, origin)); + + Point3 actual = camera.backproject(Unit3(0, 0, 1), 1.); + Point3 expected(0., 1., 0.); + pair x = camera.projectSafe(expected); + + EXPECT(assert_equal(expected, actual)); + EXPECT(assert_equal(Unit3(0, 0, 1), x.first)); + EXPECT(x.second); +} + +/* ************************************************************************* */ +static Unit3 project3(const Pose3& pose, const Point3& point) { + return Camera(pose).project(point); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, Dproject) { + Matrix Dpose, Dpoint; + Unit3 result = camera.project(point1, Dpose, Dpoint); + Matrix numerical_pose = numericalDerivative21(project3, pose, point1); + Matrix numerical_point = numericalDerivative22(project3, pose, point1); + EXPECT(assert_equal(bearing1, result)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +static Vector2 reprojectionError2(const Pose3& pose, const Point3& point, + const Unit3& measured) { + return Camera(pose).reprojectionError(point, measured); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, reprojectionError) { + Matrix Dpose, Dpoint; + Vector2 result = camera.reprojectionError(point1, bearing1, Dpose, Dpoint); + Matrix numerical_pose = + numericalDerivative31(reprojectionError2, pose, point1, bearing1); + Matrix numerical_point = + numericalDerivative32(reprojectionError2, pose, point1, bearing1); + EXPECT(assert_equal(Vector2(0.0, 0.0), result)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, reprojectionError_noisy) { + Matrix Dpose, Dpoint; + Unit3 bearing_noisy = bearing1.retract(Vector2(0.01, 0.05)); + Vector2 result = + camera.reprojectionError(point1, bearing_noisy, Dpose, Dpoint); + Matrix numerical_pose = + numericalDerivative31(reprojectionError2, pose, point1, bearing_noisy); + Matrix numerical_point = + numericalDerivative32(reprojectionError2, pose, point1, bearing_noisy); + EXPECT(assert_equal(Vector2(-0.050282, 0.00833482), result, 1e-5)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +// Add a test with more arbitrary rotation +TEST(SphericalCamera, Dproject2) { + static const Pose3 pose1(Rot3::Ypr(0.1, -0.1, 0.4), Point3(0, 0, -10)); + static const Camera camera(pose1); + Matrix Dpose, Dpoint; + camera.project2(point1, Dpose, Dpoint); + Matrix numerical_pose = numericalDerivative21(project3, pose1, point1); + Matrix numerical_point = numericalDerivative22(project3, pose1, point1); + CHECK(assert_equal(numerical_pose, Dpose, 1e-7)); + CHECK(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/geometry/tests/testTriangulation.cpp b/gtsam/geometry/tests/testTriangulation.cpp index 4f71a48dad..5fdb911d02 100644 --- a/gtsam/geometry/tests/testTriangulation.cpp +++ b/gtsam/geometry/tests/testTriangulation.cpp @@ -10,22 +10,23 @@ * -------------------------------------------------------------------------- */ /** - * testTriangulation.cpp - * - * Created on: July 30th, 2013 - * Author: cbeall3 + * @file testTriangulation.cpp + * @brief triangulation utilities + * @date July 30th, 2013 + * @author Chris Beall (cbeall3) + * @author Luca Carlone */ -#include +#include +#include +#include #include +#include #include -#include -#include -#include -#include +#include #include -#include - +#include +#include #include #include @@ -36,7 +37,7 @@ using namespace boost::assign; // Some common constants -static const boost::shared_ptr sharedCal = // +static const boost::shared_ptr sharedCal = // boost::make_shared(1500, 1200, 0, 640, 480); // Looking along X-axis, 1 meter above ground plane (x-y) @@ -57,8 +58,7 @@ Point2 z2 = camera2.project(landmark); //****************************************************************************** // Simple test with a well-behaved two camera situation -TEST( triangulation, twoPoses) { - +TEST(triangulation, twoPoses) { vector poses; Point2Vector measurements; @@ -69,36 +69,36 @@ TEST( triangulation, twoPoses) { // 1. Test simple DLT, perfect in no noise situation bool optimize = false; - boost::optional actual1 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual1 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual1, 1e-7)); // 2. test with optimization on, same answer optimize = true; - boost::optional actual2 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual2, 1e-7)); - // 3. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 3. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); optimize = false; - boost::optional actual3 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual3 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-4)); // 4. Now with optimization on optimize = true; - boost::optional actual4 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual4 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-4)); } //****************************************************************************** // Similar, but now with Bundler calibration -TEST( triangulation, twoPosesBundler) { - - boost::shared_ptr bundlerCal = // +TEST(triangulation, twoPosesBundler) { + boost::shared_ptr bundlerCal = // boost::make_shared(1500, 0, 0, 640, 480); PinholeCamera camera1(pose1, *bundlerCal); PinholeCamera camera2(pose2, *bundlerCal); @@ -116,37 +116,38 @@ TEST( triangulation, twoPosesBundler) { bool optimize = true; double rank_tol = 1e-9; - boost::optional actual = // - triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); + boost::optional actual = // + triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual, 1e-7)); // Add some noise and try again measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); + boost::optional actual2 = // + triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19847), *actual2, 1e-4)); } //****************************************************************************** -TEST( triangulation, fourPoses) { +TEST(triangulation, fourPoses) { vector poses; Point2Vector measurements; poses += pose1, pose2; measurements += z1, z2; - boost::optional actual = triangulatePoint3(poses, sharedCal, - measurements); + boost::optional actual = + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *actual, 1e-2)); - // 2. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 2. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(poses, sharedCal, measurements); + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *actual2, 1e-2)); // 3. Add a slightly rotated third camera above, again with measurement noise @@ -157,13 +158,13 @@ TEST( triangulation, fourPoses) { poses += pose3; measurements += z3 + Point2(0.1, -0.1); - boost::optional triangulated_3cameras = // - triangulatePoint3(poses, sharedCal, measurements); + boost::optional triangulated_3cameras = // + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *triangulated_3cameras, 1e-2)); // Again with nonlinear optimization - boost::optional triangulated_3cameras_opt = triangulatePoint3(poses, - sharedCal, measurements, 1e-9, true); + boost::optional triangulated_3cameras_opt = + triangulatePoint3(poses, sharedCal, measurements, 1e-9, true); EXPECT(assert_equal(landmark, *triangulated_3cameras_opt, 1e-2)); // 4. Test failure: Add a 4th camera facing the wrong way @@ -176,13 +177,13 @@ TEST( triangulation, fourPoses) { poses += pose4; measurements += Point2(400, 400); - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationCheiralityException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationCheiralityException); #endif } //****************************************************************************** -TEST( triangulation, fourPoses_distinct_Ks) { +TEST(triangulation, fourPoses_distinct_Ks) { Cal3_S2 K1(1500, 1200, 0, 640, 480); // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, K1); @@ -195,22 +196,23 @@ TEST( triangulation, fourPoses_distinct_Ks) { Point2 z1 = camera1.project(landmark); Point2 z2 = camera2.project(landmark); - CameraSet > cameras; + CameraSet> cameras; Point2Vector measurements; cameras += camera1, camera2; measurements += z1, z2; - boost::optional actual = // - triangulatePoint3(cameras, measurements); + boost::optional actual = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *actual, 1e-2)); - // 2. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 2. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(cameras, measurements); + boost::optional actual2 = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *actual2, 1e-2)); // 3. Add a slightly rotated third camera above, again with measurement noise @@ -222,13 +224,13 @@ TEST( triangulation, fourPoses_distinct_Ks) { cameras += camera3; measurements += z3 + Point2(0.1, -0.1); - boost::optional triangulated_3cameras = // - triangulatePoint3(cameras, measurements); + boost::optional triangulated_3cameras = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *triangulated_3cameras, 1e-2)); // Again with nonlinear optimization - boost::optional triangulated_3cameras_opt = triangulatePoint3(cameras, - measurements, 1e-9, true); + boost::optional triangulated_3cameras_opt = + triangulatePoint3(cameras, measurements, 1e-9, true); EXPECT(assert_equal(landmark, *triangulated_3cameras_opt, 1e-2)); // 4. Test failure: Add a 4th camera facing the wrong way @@ -241,13 +243,13 @@ TEST( triangulation, fourPoses_distinct_Ks) { cameras += camera4; measurements += Point2(400, 400); - CHECK_EXCEPTION(triangulatePoint3(cameras, measurements), - TriangulationCheiralityException); + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements), + TriangulationCheiralityException); #endif } //****************************************************************************** -TEST( triangulation, outliersAndFarLandmarks) { +TEST(triangulation, outliersAndFarLandmarks) { Cal3_S2 K1(1500, 1200, 0, 640, 480); // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, K1); @@ -260,24 +262,29 @@ TEST( triangulation, outliersAndFarLandmarks) { Point2 z1 = camera1.project(landmark); Point2 z2 = camera2.project(landmark); - CameraSet > cameras; + CameraSet> cameras; Point2Vector measurements; cameras += camera1, camera2; measurements += z1, z2; - double landmarkDistanceThreshold = 10; // landmark is closer than that - TriangulationParameters params(1.0, false, landmarkDistanceThreshold); // all default except landmarkDistanceThreshold - TriangulationResult actual = triangulateSafe(cameras,measurements,params); + double landmarkDistanceThreshold = 10; // landmark is closer than that + TriangulationParameters params( + 1.0, false, landmarkDistanceThreshold); // all default except + // landmarkDistanceThreshold + TriangulationResult actual = triangulateSafe(cameras, measurements, params); EXPECT(assert_equal(landmark, *actual, 1e-2)); EXPECT(actual.valid()); - landmarkDistanceThreshold = 4; // landmark is farther than that - TriangulationParameters params2(1.0, false, landmarkDistanceThreshold); // all default except landmarkDistanceThreshold - actual = triangulateSafe(cameras,measurements,params2); + landmarkDistanceThreshold = 4; // landmark is farther than that + TriangulationParameters params2( + 1.0, false, landmarkDistanceThreshold); // all default except + // landmarkDistanceThreshold + actual = triangulateSafe(cameras, measurements, params2); EXPECT(actual.farPoint()); - // 3. Add a slightly rotated third camera above with a wrong measurement (OUTLIER) + // 3. Add a slightly rotated third camera above with a wrong measurement + // (OUTLIER) Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1)); Cal3_S2 K3(700, 500, 0, 640, 480); PinholeCamera camera3(pose3, K3); @@ -286,21 +293,23 @@ TEST( triangulation, outliersAndFarLandmarks) { cameras += camera3; measurements += z3 + Point2(10, -10); - landmarkDistanceThreshold = 10; // landmark is closer than that - double outlierThreshold = 100; // loose, the outlier is going to pass - TriangulationParameters params3(1.0, false, landmarkDistanceThreshold,outlierThreshold); - actual = triangulateSafe(cameras,measurements,params3); + landmarkDistanceThreshold = 10; // landmark is closer than that + double outlierThreshold = 100; // loose, the outlier is going to pass + TriangulationParameters params3(1.0, false, landmarkDistanceThreshold, + outlierThreshold); + actual = triangulateSafe(cameras, measurements, params3); EXPECT(actual.valid()); // now set stricter threshold for outlier rejection - outlierThreshold = 5; // tighter, the outlier is not going to pass - TriangulationParameters params4(1.0, false, landmarkDistanceThreshold,outlierThreshold); - actual = triangulateSafe(cameras,measurements,params4); + outlierThreshold = 5; // tighter, the outlier is not going to pass + TriangulationParameters params4(1.0, false, landmarkDistanceThreshold, + outlierThreshold); + actual = triangulateSafe(cameras, measurements, params4); EXPECT(actual.outlier()); } //****************************************************************************** -TEST( triangulation, twoIdenticalPoses) { +TEST(triangulation, twoIdenticalPoses) { // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, *sharedCal); @@ -313,12 +322,12 @@ TEST( triangulation, twoIdenticalPoses) { poses += pose1, pose1; measurements += z1, z1; - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationUnderconstrainedException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationUnderconstrainedException); } //****************************************************************************** -TEST( triangulation, onePose) { +TEST(triangulation, onePose) { // we expect this test to fail with a TriangulationUnderconstrainedException // because there's only one camera observation @@ -326,28 +335,26 @@ TEST( triangulation, onePose) { Point2Vector measurements; poses += Pose3(); - measurements += Point2(0,0); + measurements += Point2(0, 0); - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationUnderconstrainedException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationUnderconstrainedException); } //****************************************************************************** -TEST( triangulation, StereotriangulateNonlinear ) { - - auto stereoK = boost::make_shared(1733.75, 1733.75, 0, 689.645, 508.835, 0.0699612); +TEST(triangulation, StereotriangulateNonlinear) { + auto stereoK = boost::make_shared(1733.75, 1733.75, 0, 689.645, + 508.835, 0.0699612); // two camera poses m1, m2 Matrix4 m1, m2; - m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, - 0.592783835, -0.77156583, 0.230856632, 66.2186159, - 0.116517574, -0.201470143, -0.9725393, -4.28382528, - 0, 0, 0, 1; + m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, 0.592783835, + -0.77156583, 0.230856632, 66.2186159, 0.116517574, -0.201470143, + -0.9725393, -4.28382528, 0, 0, 0, 1; - m2 << -0.955959025, -0.29288915, -0.0189328569, 45.7169799, - -0.29277519, 0.947083213, 0.131587097, 65.843136, - -0.0206094928, 0.131334858, -0.991123524, -4.3525033, - 0, 0, 0, 1; + m2 << -0.955959025, -0.29288915, -0.0189328569, 45.7169799, -0.29277519, + 0.947083213, 0.131587097, 65.843136, -0.0206094928, 0.131334858, + -0.991123524, -4.3525033, 0, 0, 0, 1; typedef CameraSet Cameras; Cameras cameras; @@ -358,18 +365,19 @@ TEST( triangulation, StereotriangulateNonlinear ) { measurements += StereoPoint2(226.936, 175.212, 424.469); measurements += StereoPoint2(339.571, 285.547, 669.973); - Point3 initial = Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191 + Point3 initial = + Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191 - Point3 actual = triangulateNonlinear(cameras, measurements, initial); + Point3 actual = triangulateNonlinear(cameras, measurements, initial); - Point3 expected(46.0484569, 66.4710686, -6.55046613); // error: 0.763510644187 + Point3 expected(46.0484569, 66.4710686, + -6.55046613); // error: 0.763510644187 EXPECT(assert_equal(expected, actual, 1e-4)); - // regular stereo factor comparison - expect very similar result as above { - typedef GenericStereoFactor StereoFactor; + typedef GenericStereoFactor StereoFactor; Values values; values.insert(Symbol('x', 1), Pose3(m1)); @@ -378,17 +386,19 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - graph.emplace_shared(measurements[0], unit, Symbol('x',1), Symbol('l',1), stereoK); - graph.emplace_shared(measurements[1], unit, Symbol('x',2), Symbol('l',1), stereoK); + graph.emplace_shared(measurements[0], unit, Symbol('x', 1), + Symbol('l', 1), stereoK); + graph.emplace_shared(measurements[1], unit, Symbol('x', 2), + Symbol('l', 1), stereoK); const SharedDiagonal posePrior = noiseModel::Isotropic::Sigma(6, 1e-9); - graph.addPrior(Symbol('x',1), Pose3(m1), posePrior); - graph.addPrior(Symbol('x',2), Pose3(m2), posePrior); + graph.addPrior(Symbol('x', 1), Pose3(m1), posePrior); + graph.addPrior(Symbol('x', 2), Pose3(m2), posePrior); LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } // use Triangulation Factor directly - expect same result as above @@ -399,13 +409,15 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - graph.emplace_shared >(cameras[0], measurements[0], unit, Symbol('l',1)); - graph.emplace_shared >(cameras[1], measurements[1], unit, Symbol('l',1)); + graph.emplace_shared>( + cameras[0], measurements[0], unit, Symbol('l', 1)); + graph.emplace_shared>( + cameras[1], measurements[1], unit, Symbol('l', 1)); LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } // use ExpressionFactor - expect same result as above @@ -416,11 +428,13 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - Expression point_(Symbol('l',1)); + Expression point_(Symbol('l', 1)); Expression camera0_(cameras[0]); Expression camera1_(cameras[1]); - Expression project0_(camera0_, &StereoCamera::project2, point_); - Expression project1_(camera1_, &StereoCamera::project2, point_); + Expression project0_(camera0_, &StereoCamera::project2, + point_); + Expression project1_(camera1_, &StereoCamera::project2, + point_); graph.addExpressionFactor(unit, measurements[0], project0_); graph.addExpressionFactor(unit, measurements[1], project1_); @@ -428,10 +442,172 @@ TEST( triangulation, StereotriangulateNonlinear ) { LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } } +//****************************************************************************** +// Simple test with a well-behaved two camera situation +TEST(triangulation, twoPoses_sphericalCamera) { + vector poses; + std::vector measurements; + + // Project landmark into two cameras and triangulate + SphericalCamera cam1(pose1); + SphericalCamera cam2(pose2); + Unit3 u1 = cam1.project(landmark); + Unit3 u2 = cam2.project(landmark); + + poses += pose1, pose2; + measurements += u1, u2; + + CameraSet cameras; + cameras.push_back(cam1); + cameras.push_back(cam2); + + double rank_tol = 1e-9; + + // 1. Test linear triangulation via DLT + auto projection_matrices = projectionMatricesFromCameras(cameras); + Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); + EXPECT(assert_equal(landmark, point, 1e-7)); + + // 2. Test nonlinear triangulation + point = triangulateNonlinear(cameras, measurements, point); + EXPECT(assert_equal(landmark, point, 1e-7)); + + // 3. Test simple DLT, now within triangulatePoint3 + bool optimize = false; + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmark, *actual1, 1e-7)); + + // 4. test with optimization on, same answer + optimize = true; + boost::optional actual2 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmark, *actual2, 1e-7)); + + // 5. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) + measurements.at(0) = + u1.retract(Vector2(0.01, 0.05)); // note: perturbation smaller for Unit3 + measurements.at(1) = u2.retract(Vector2(-0.02, 0.03)); + optimize = false; + boost::optional actual3 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(Point3(5.9432, 0.654319, 1.48192), *actual3, 1e-3)); + + // 6. Now with optimization on + optimize = true; + boost::optional actual4 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(Point3(5.9432, 0.654334, 1.48192), *actual4, 1e-3)); +} + +//****************************************************************************** +TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) { + vector poses; + std::vector measurements; + + // Project landmark into two cameras and triangulate + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(2.0, 0.0, 0.0)); // 2m in front of poseA + Point3 landmarkL( + 1.0, -1.0, + 0.0); // 1m to the right of both cameras, in front of poseA, behind poseB + SphericalCamera cam1(poseA); + SphericalCamera cam2(poseB); + Unit3 u1 = cam1.project(landmarkL); + Unit3 u2 = cam2.project(landmarkL); + + EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, 1.0)), u1, + 1e-7)); // in front and to the right of PoseA + EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, -1.0)), u2, + 1e-7)); // behind and to the right of PoseB + + poses += pose1, pose2; + measurements += u1, u2; + + CameraSet cameras; + cameras.push_back(cam1); + cameras.push_back(cam2); + + double rank_tol = 1e-9; + + { + // 1. Test simple DLT, when 1 point is behind spherical camera + bool optimize = false; +#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements, + rank_tol, optimize), + TriangulationCheiralityException); +#else // otherwise project should not throw the exception + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmarkL, *actual1, 1e-7)); +#endif + } + { + // 2. test with optimization on, same answer + bool optimize = true; +#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements, + rank_tol, optimize), + TriangulationCheiralityException); +#else // otherwise project should not throw the exception + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmarkL, *actual1, 1e-7)); +#endif + } +} + +//****************************************************************************** +TEST(triangulation, reprojectionError_cameraComparison) { + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Point3 landmarkL(5.0, 0.0, 0.0); // 1m in front of poseA + SphericalCamera sphericalCamera(poseA); + Unit3 u = sphericalCamera.project(landmarkL); + + static Cal3_S2::shared_ptr sharedK(new Cal3_S2(60, 640, 480)); + PinholePose pinholeCamera(poseA, sharedK); + Vector2 px = pinholeCamera.project(landmarkL); + + // add perturbation and compare error in both cameras + Vector2 px_noise(1.0, 2.0); // px perturbation vertically and horizontally + Vector2 measured_px = px + px_noise; + Vector2 measured_px_calibrated = sharedK->calibrate(measured_px); + Unit3 measured_u = + Unit3(measured_px_calibrated[0], measured_px_calibrated[1], 1.0); + Unit3 expected_measured_u = + Unit3(px_noise[0] / sharedK->fx(), px_noise[1] / sharedK->fy(), 1.0); + EXPECT(assert_equal(expected_measured_u, measured_u, 1e-7)); + + Vector2 actualErrorPinhole = + pinholeCamera.reprojectionError(landmarkL, measured_px); + Vector2 expectedErrorPinhole = Vector2(-px_noise[0], -px_noise[1]); + EXPECT(assert_equal(expectedErrorPinhole, actualErrorPinhole, + 1e-7)); //- sign due to definition of error + + Vector2 actualErrorSpherical = + sphericalCamera.reprojectionError(landmarkL, measured_u); + // expectedError: not easy to calculate, since it involves the unit3 basis + Vector2 expectedErrorSpherical(-0.00360842, 0.00180419); + EXPECT(assert_equal(expectedErrorSpherical, actualErrorSpherical, 1e-7)); +} + //****************************************************************************** int main() { TestResult tr; diff --git a/gtsam/geometry/triangulation.cpp b/gtsam/geometry/triangulation.cpp index a5d2e04cd4..026afef246 100644 --- a/gtsam/geometry/triangulation.cpp +++ b/gtsam/geometry/triangulation.cpp @@ -53,15 +53,57 @@ Vector4 triangulateHomogeneousDLT( return v; } -Point3 triangulateDLT(const std::vector>& projection_matrices, - const Point2Vector& measurements, double rank_tol) { +Vector4 triangulateHomogeneousDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol) { + + // number of cameras + size_t m = projection_matrices.size(); + + // Allocate DLT matrix + Matrix A = Matrix::Zero(m * 2, 4); + + for (size_t i = 0; i < m; i++) { + size_t row = i * 2; + const Matrix34& projection = projection_matrices.at(i); + const Point3& p = measurements.at(i).point3(); // to get access to x,y,z of the bearing vector + + // build system of equations + A.row(row) = p.x() * projection.row(2) - p.z() * projection.row(0); + A.row(row + 1) = p.y() * projection.row(2) - p.z() * projection.row(1); + } + int rank; + double error; + Vector v; + boost::tie(rank, error, v) = DLT(A, rank_tol); + + if (rank < 3) + throw(TriangulationUnderconstrainedException()); + + return v; +} - Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, rank_tol); +Point3 triangulateDLT( + const std::vector>& projection_matrices, + const Point2Vector& measurements, double rank_tol) { + Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, + rank_tol); // Create 3D point from homogeneous coordinates return Point3(v.head<3>() / v[3]); } +Point3 triangulateDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol) { + + // contrary to previous triangulateDLT, this is now taking Unit3 inputs + Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, + rank_tol); + // Create 3D point from homogeneous coordinates + return Point3(v.head<3>() / v[3]); +} + /// /** * Optimize for triangulation @@ -71,7 +113,7 @@ Point3 triangulateDLT(const std::vector #include #include +#include #include #include #include @@ -59,6 +60,18 @@ GTSAM_EXPORT Vector4 triangulateHomogeneousDLT( const std::vector>& projection_matrices, const Point2Vector& measurements, double rank_tol = 1e-9); +/** + * Same math as Hartley and Zisserman, 2nd Ed., page 312, but with unit-norm bearing vectors + * (contrarily to pinhole projection, the z entry is not assumed to be 1 as in Hartley and Zisserman) + * @param projection_matrices Projection matrices (K*P^-1) + * @param measurements Unit3 bearing measurements + * @param rank_tol SVD rank tolerance + * @return Triangulated point, in homogeneous coordinates + */ +GTSAM_EXPORT Vector4 triangulateHomogeneousDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol = 1e-9); + /** * DLT triangulation: See Hartley and Zisserman, 2nd Ed., page 312 * @param projection_matrices Projection matrices (K*P^-1) @@ -71,6 +84,14 @@ GTSAM_EXPORT Point3 triangulateDLT( const Point2Vector& measurements, double rank_tol = 1e-9); +/** + * overload of previous function to work with Unit3 (projected to canonical camera) + */ +GTSAM_EXPORT Point3 triangulateDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, + double rank_tol = 1e-9); + /** * Create a factor graph with projection factors from poses and one calibration * @param poses Camera poses @@ -180,26 +201,27 @@ Point3 triangulateNonlinear( return optimize(graph, values, Symbol('p', 0)); } -/** - * Create a 3*4 camera projection matrix from calibration and pose. - * Functor for partial application on calibration - * @param pose The camera pose - * @param cal The calibration - * @return Returns a Matrix34 - */ -template -struct CameraProjectionMatrix { - CameraProjectionMatrix(const CALIBRATION& calibration) : - K_(calibration.K()) { +template +std::vector> +projectionMatricesFromCameras(const CameraSet &cameras) { + std::vector> projection_matrices; + for (const CAMERA &camera: cameras) { + projection_matrices.push_back(camera.cameraProjectionMatrix()); } - Matrix34 operator()(const Pose3& pose) const { - return K_ * (pose.inverse().matrix()).block<3, 4>(0, 0); + return projection_matrices; +} + +// overload, assuming pinholePose +template +std::vector> projectionMatricesFromPoses( + const std::vector &poses, boost::shared_ptr sharedCal) { + std::vector> projection_matrices; + for (size_t i = 0; i < poses.size(); i++) { + PinholePose camera(poses.at(i), sharedCal); + projection_matrices.push_back(camera.cameraProjectionMatrix()); } -private: - const Matrix3 K_; -public: - GTSAM_MAKE_ALIGNED_OPERATOR_NEW -}; + return projection_matrices; +} /** * Function to triangulate 3D landmark point from an arbitrary number @@ -224,10 +246,7 @@ Point3 triangulatePoint3(const std::vector& poses, throw(TriangulationUnderconstrainedException()); // construct projection matrices from poses & calibration - std::vector> projection_matrices; - CameraProjectionMatrix createP(*sharedCal); // partially apply - for(const Pose3& pose: poses) - projection_matrices.push_back(createP(pose)); + auto projection_matrices = projectionMatricesFromPoses(poses, sharedCal); // Triangulate linearly Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); @@ -274,11 +293,7 @@ Point3 triangulatePoint3( throw(TriangulationUnderconstrainedException()); // construct projection matrices from poses & calibration - std::vector> projection_matrices; - for(const CAMERA& camera: cameras) - projection_matrices.push_back( - CameraProjectionMatrix(camera.calibration())( - camera.pose())); + auto projection_matrices = projectionMatricesFromCameras(cameras); Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); // The n refine using non-linear optimization @@ -474,8 +489,8 @@ TriangulationResult triangulateSafe(const CameraSet& cameras, #endif // Check reprojection error if (params.dynamicOutlierRejectionThreshold > 0) { - const Point2& zi = measured.at(i); - Point2 reprojectionError(camera.project(point) - zi); + const typename CAMERA::Measurement& zi = measured.at(i); + Point2 reprojectionError = camera.reprojectionError(point, zi); maxReprojError = std::max(maxReprojError, reprojectionError.norm()); } i += 1; @@ -503,6 +518,6 @@ using CameraSetCal3Bundler = CameraSet>; using CameraSetCal3_S2 = CameraSet>; using CameraSetCal3Fisheye = CameraSet>; using CameraSetCal3Unified = CameraSet>; - +using CameraSetSpherical = CameraSet; } // \namespace gtsam diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 67c3278a36..d4e959c3dd 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -39,9 +39,6 @@ class KeyList { void remove(size_t key); void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a FastSet @@ -67,9 +64,6 @@ class KeySet { bool count(size_t key) const; // returns true if value exists void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a vector @@ -91,9 +85,6 @@ class KeyVector { void push_back(size_t key) const; void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a FastMap @@ -165,6 +156,7 @@ gtsam::Values allPose2s(gtsam::Values& values); Matrix extractPose2(const gtsam::Values& values); gtsam::Values allPose3s(gtsam::Values& values); Matrix extractPose3(const gtsam::Values& values); +Matrix extractVectors(const gtsam::Values& values, char c); void perturbPoint2(gtsam::Values& values, double sigma, int seed = 42u); void perturbPose2(gtsam::Values& values, double sigmaT, double sigmaR, int seed = 42u); diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index a737622585..be34b2928f 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -35,21 +35,39 @@ void BayesNet::print( /* ************************************************************************* */ template -void BayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional : boost::adaptors::reverse(*this)) { - typename CONDITIONAL::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename CONDITIONAL::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; +void BayesNet::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + os << "digraph G{\n"; + + for (auto conditional : *this) { + auto frontals = conditional->frontals(); + const Key me = frontals.front(); + auto parents = conditional->parents(); + for (const Key& p : parents) + os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; } - of << "}"; + os << "}"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string BayesNet::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); +} + +/* ************************************************************************* */ +template +void BayesNet::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } +/* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 938278d5ad..f987ad51be 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -64,11 +64,21 @@ namespace gtsam { /// @} - /// @name Standard Interface + /// @name Graph Display /// @{ - void saveGraph(const std::string& s, + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; } diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 5b53a57193..9b937fefb5 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -63,20 +63,40 @@ namespace gtsam { } /* ************************************************************************* */ - template - void BayesTree::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { - if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); - std::ofstream of(s.c_str()); - of<< "digraph G{\n"; - for(const sharedClique& root: roots_) - saveGraph(of, root, keyFormatter); - of<<"}"; + template + void BayesTree::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + if (roots_.empty()) + throw std::invalid_argument( + "the root of Bayes tree has not been initialized!"); + os << "digraph G{\n"; + for (const sharedClique& root : roots_) dot(os, root, keyFormatter); + os << "}"; + std::flush(os); + } + + /* ************************************************************************* */ + template + std::string BayesTree::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); + } + + /* ************************************************************************* */ + template + void BayesTree::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } /* ************************************************************************* */ - template - void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { + template + void BayesTree::dot(std::ostream& s, sharedClique clique, + const KeyFormatter& indexFormatter, + int parentnum) const { static int num = 0; bool first = true; std::stringstream out; @@ -107,7 +127,7 @@ namespace gtsam { for (sharedClique c : clique->children) { num++; - saveGraph(s, c, indexFormatter, parentnum); + dot(s, c, indexFormatter, parentnum); } } diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index cc003d8dcb..68a45a014a 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -182,13 +182,20 @@ namespace gtsam { */ sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; - /** - * Read only with side effects - */ + /// @name Graph Display + /// @{ + + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /** saves the Tree to a text file in GraphViz format */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} /// @name Advanced Interface /// @{ @@ -236,8 +243,8 @@ namespace gtsam { protected: /** private helper method for saving the Tree to a text file in GraphViz format */ - void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, - int parentnum = 0) const; + void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, + int parentnum = 0) const; /** Gather data on a single clique */ void getCliqueData(sharedClique clique, BayesTreeCliqueData* stats) const; @@ -249,7 +256,7 @@ namespace gtsam { void fillNodesIndex(const sharedClique& subtree); // Friend JunctionTree because it directly fills roots and nodes index. - template friend class EliminatableClusterTree; + template friend class EliminatableClusterTree; private: /** Serialization function */ diff --git a/gtsam/inference/DotWriter.cpp b/gtsam/inference/DotWriter.cpp new file mode 100644 index 0000000000..fb3ea05054 --- /dev/null +++ b/gtsam/inference/DotWriter.cpp @@ -0,0 +1,93 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.cpp + * @brief Graphviz formatting for factor graphs. + * @author Frank Dellaert + * @date December, 2021 + */ + +#include +#include + +#include + +using namespace std; + +namespace gtsam { + +void DotWriter::writePreamble(ostream* os) const { + *os << "graph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + +void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + ostream* os) { + // Label the node with the label from the KeyFormatter + *os << " var" << key << "[label=\"" << keyFormatter(key) << "\""; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + *os << "];\n"; +} + +void DotWriter::DrawFactor(size_t i, const boost::optional& position, + ostream* os) { + *os << " factor" << i << "[label=\"\", shape=point"; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + *os << "];\n"; +} + +void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) { + *os << " var" << key1 << "--" + << "var" << key2 << ";\n"; +} + +void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) { + *os << " var" << key << "--" + << "factor" << i << ";\n"; +} + +void DotWriter::processFactor(size_t i, const KeyVector& keys, + const boost::optional& position, + ostream* os) const { + if (plotFactorPoints) { + if (binaryEdges && keys.size() == 2) { + ConnectVariables(keys[0], keys[1], os); + } else { + // Create dot for the factor. + DrawFactor(i, position, os); + + // Make factor-variable connections + if (connectKeysToFactor) { + for (Key key : keys) { + ConnectVariableFactor(key, i, os); + } + } + } + } else { + // just connect variables in a clique + for (Key key1 : keys) { + for (Key key2 : keys) { + if (key2 > key1) { + ConnectVariables(key1, key2, os); + } + } + } + } +} + +} // namespace gtsam diff --git a/gtsam/inference/DotWriter.h b/gtsam/inference/DotWriter.h new file mode 100644 index 0000000000..bd36da496c --- /dev/null +++ b/gtsam/inference/DotWriter.h @@ -0,0 +1,72 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.h + * @brief Graphviz formatter + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include +#include +#include + +#include + +namespace gtsam { + +/// Graphviz formatter. +struct GTSAM_EXPORT DotWriter { + double figureWidthInches; ///< The figure width on paper in inches + double figureHeightInches; ///< The figure height on paper in inches + bool plotFactorPoints; ///< Plots each factor as a dot between the variables + bool connectKeysToFactor; ///< Draw a line from each key within a factor to + ///< the dot of the factor + bool binaryEdges; ///< just use non-dotted edges for binary factors + + explicit DotWriter(double figureWidthInches = 5, + double figureHeightInches = 5, + bool plotFactorPoints = true, + bool connectKeysToFactor = true, bool binaryEdges = true) + : figureWidthInches(figureWidthInches), + figureHeightInches(figureHeightInches), + plotFactorPoints(plotFactorPoints), + connectKeysToFactor(connectKeysToFactor), + binaryEdges(binaryEdges) {} + + /// Write out preamble, including size. + void writePreamble(std::ostream* os) const; + + /// Create a variable dot fragment. + static void DrawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + std::ostream* os); + + /// Create factor dot. + static void DrawFactor(size_t i, const boost::optional& position, + std::ostream* os); + + /// Connect two variables. + static void ConnectVariables(Key key1, Key key2, std::ostream* os); + + /// Connect variable and factor. + static void ConnectVariableFactor(Key key, size_t i, std::ostream* os); + + /// Draw a single factor, specified by its index i and its variable keys. + void processFactor(size_t i, const KeyVector& keys, + const boost::optional& position, + std::ostream* os) const; +}; + +} // namespace gtsam diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index 166ae41f41..058075f2d5 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -26,6 +26,7 @@ #include #include #include // for cout :-( +#include #include #include @@ -125,4 +126,49 @@ FactorIndices FactorGraph::add_factors(const CONTAINER& factors, return newFactorIndices; } +/* ************************************************************************* */ +template +void FactorGraph::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + writer.writePreamble(&os); + + // Create nodes for each variable in the graph + for (Key key : keys()) { + writer.DrawVariable(key, keyFormatter, boost::none, &os); + } + os << "\n"; + + // Create factors and variable connections + for (size_t i = 0; i < size(); ++i) { + const auto& factor = at(i); + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, boost::none, &os); + } + } + + os << "}\n"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string FactorGraph::dot(const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::stringstream ss; + dot(ss, keyFormatter, writer); + return ss.str(); +} + +/* ************************************************************************* */ +template +void FactorGraph::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter, writer); + of.close(); +} + } // namespace gtsam diff --git a/gtsam/inference/FactorGraph.h b/gtsam/inference/FactorGraph.h index e337e3249f..9c0f10f9a5 100644 --- a/gtsam/inference/FactorGraph.h +++ b/gtsam/inference/FactorGraph.h @@ -22,9 +22,10 @@ #pragma once +#include +#include #include #include -#include #include // for Eigen::aligned_allocator @@ -36,6 +37,7 @@ #include #include #include +#include namespace gtsam { /// Define collection type: @@ -371,6 +373,24 @@ class FactorGraph { return factors_.erase(first, last); } + /// @} + /// @name Graph Display + /// @{ + + /// Output to graphviz format, stream version. + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + + /// Output to graphviz format string. + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/linear/Errors.cpp b/gtsam/linear/Errors.cpp index 4b30dcc087..41c6c3d09e 100644 --- a/gtsam/linear/Errors.cpp +++ b/gtsam/linear/Errors.cpp @@ -110,7 +110,6 @@ double dot(const Errors& a, const Errors& b) { } /* ************************************************************************* */ -template<> void axpy(double alpha, const Errors& x, Errors& y) { Errors::const_iterator it = x.begin(); for(Vector& yi: y) diff --git a/gtsam/linear/Errors.h b/gtsam/linear/Errors.h index e8ba7344ed..f6e147084b 100644 --- a/gtsam/linear/Errors.h +++ b/gtsam/linear/Errors.h @@ -65,7 +65,6 @@ namespace gtsam { /** * BLAS level 2 style */ - template <> GTSAM_EXPORT void axpy(double alpha, const Errors& x, Errors& y); /** print with optional string */ diff --git a/gtsam/linear/GaussianFactorGraph.cpp b/gtsam/linear/GaussianFactorGraph.cpp index 664aeff6d0..72eb107d09 100644 --- a/gtsam/linear/GaussianFactorGraph.cpp +++ b/gtsam/linear/GaussianFactorGraph.cpp @@ -379,7 +379,7 @@ namespace gtsam { gttic(Compute_minimizing_step_size); // Compute minimizing step size - double step = -gradientSqNorm / dot(Rg, Rg); + double step = -gradientSqNorm / gtsam::dot(Rg, Rg); gttoc(Compute_minimizing_step_size); gttic(Compute_point); diff --git a/gtsam/linear/linear.i b/gtsam/linear/linear.i index d882cb38b5..7b1ce550f0 100644 --- a/gtsam/linear/linear.i +++ b/gtsam/linear/linear.i @@ -34,9 +34,6 @@ virtual class Gaussian : gtsam::noiseModel::Base { // enabling serialization functionality void serializable() const; - - // enable pickling in python - void pickle() const; }; virtual class Diagonal : gtsam::noiseModel::Gaussian { @@ -52,9 +49,6 @@ virtual class Diagonal : gtsam::noiseModel::Gaussian { // enabling serialization functionality void serializable() const; - - // enable pickling in python - void pickle() const; }; virtual class Constrained : gtsam::noiseModel::Diagonal { @@ -72,9 +66,6 @@ virtual class Constrained : gtsam::noiseModel::Diagonal { // enabling serialization functionality void serializable() const; - - // enable pickling in python - void pickle() const; }; virtual class Isotropic : gtsam::noiseModel::Diagonal { @@ -87,9 +78,6 @@ virtual class Isotropic : gtsam::noiseModel::Diagonal { // enabling serialization functionality void serializable() const; - - // enable pickling in python - void pickle() const; }; virtual class Unit : gtsam::noiseModel::Isotropic { @@ -97,9 +85,6 @@ virtual class Unit : gtsam::noiseModel::Isotropic { // enabling serialization functionality void serializable() const; - - // enable pickling in python - void pickle() const; }; namespace mEstimator { @@ -270,9 +255,6 @@ class VectorValues { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -344,9 +326,6 @@ virtual class JacobianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -379,9 +358,6 @@ virtual class HessianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -463,9 +439,6 @@ class GaussianFactorGraph { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include diff --git a/gtsam/navigation/BarometricFactor.cpp b/gtsam/navigation/BarometricFactor.cpp new file mode 100644 index 0000000000..2f0ff7436d --- /dev/null +++ b/gtsam/navigation/BarometricFactor.cpp @@ -0,0 +1,55 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BarometricFactor.cpp + * @author Peter Milani + * @brief Implementation file for Barometric factor + * @date December 16, 2021 + **/ + +#include "BarometricFactor.h" + +using namespace std; + +namespace gtsam { + +//*************************************************************************** +void BarometricFactor::print(const string& s, + const KeyFormatter& keyFormatter) const { + cout << (s.empty() ? "" : s + " ") << "Barometric Factor on " + << keyFormatter(key1()) << "Barometric Bias on " + << keyFormatter(key2()) << "\n"; + + cout << " Baro measurement: " << nT_ << "\n"; + noiseModel_->print(" noise model: "); +} + +//*************************************************************************** +bool BarometricFactor::equals(const NonlinearFactor& expected, + double tol) const { + const This* e = dynamic_cast(&expected); + return e != nullptr && Base::equals(*e, tol) && + traits::Equals(nT_, e->nT_, tol); +} + +//*************************************************************************** +Vector BarometricFactor::evaluateError(const Pose3& p, const double& bias, + boost::optional H, + boost::optional H2) const { + Matrix tH; + Vector ret = (Vector(1) << (p.translation(tH).z() + bias - nT_)).finished(); + if (H) (*H) = tH.block<1, 6>(2, 0); + if (H2) (*H2) = (Matrix(1, 1) << 1.0).finished(); + return ret; +} + +} // namespace gtsam diff --git a/gtsam/navigation/BarometricFactor.h b/gtsam/navigation/BarometricFactor.h new file mode 100644 index 0000000000..e7bf6f9989 --- /dev/null +++ b/gtsam/navigation/BarometricFactor.h @@ -0,0 +1,109 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BarometricFactor.h + * @author Peter Milani + * @brief Header file for Barometric factor + * @date December 16, 2021 + **/ +#pragma once + +#include +#include +#include + +namespace gtsam { + +/** + * Prior on height in a cartesian frame. + * Receive barometric pressure in kilopascals + * Model with a slowly moving bias to capture differences + * between the height and the standard atmosphere + * https://www.grc.nasa.gov/www/k-12/airplane/atmosmet.html + * @addtogroup Navigation + */ +class GTSAM_EXPORT BarometricFactor : public NoiseModelFactor2 { + private: + typedef NoiseModelFactor2 Base; + + double nT_; ///< Height Measurement based on a standard atmosphere + + public: + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; + + /// Typedef to this class + typedef BarometricFactor This; + + /** default constructor - only use for serialization */ + BarometricFactor() : nT_(0) {} + + ~BarometricFactor() override {} + + /** + * @brief Constructor from a measurement of pressure in KPa. + * @param key of the Pose3 variable that will be constrained + * @param key of the barometric bias that will be constrained + * @param baroIn measurement in KPa + * @param model Gaussian noise model 1 dimension + */ + BarometricFactor(Key key, Key baroKey, const double& baroIn, + const SharedNoiseModel& model) + : Base(model, key, baroKey), nT_(heightOut(baroIn)) {} + + /// @return a deep copy of this factor + gtsam::NonlinearFactor::shared_ptr clone() const override { + return boost::static_pointer_cast( + gtsam::NonlinearFactor::shared_ptr(new This(*this))); + } + + /// print + void print( + const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + + /// equals + bool equals(const NonlinearFactor& expected, + double tol = 1e-9) const override; + + /// vector of errors + Vector evaluateError( + const Pose3& p, const double& b, + boost::optional H = boost::none, + boost::optional H2 = boost::none) const override; + + inline const double& measurementIn() const { return nT_; } + + inline double heightOut(double n) const { + // From https://www.grc.nasa.gov/www/k-12/airplane/atmosmet.html + return (std::pow(n / 101.29, 1. / 5.256) * 288.08 - 273.1 - 15.04) / + -0.00649; + }; + + inline double baroOut(const double& meters) { + double temp = 15.04 - 0.00649 * meters; + return 101.29 * std::pow(((temp + 273.1) / 288.08), 5.256); + }; + + private: + /// Serialization function + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "NoiseModelFactor1", + boost::serialization::base_object(*this)); + ar& BOOST_SERIALIZATION_NVP(nT_); + } +}; + +} // namespace gtsam diff --git a/gtsam/navigation/navigation.i b/gtsam/navigation/navigation.i index 1f9ffcf2e5..c2a3bcdb42 100644 --- a/gtsam/navigation/navigation.i +++ b/gtsam/navigation/navigation.i @@ -44,9 +44,6 @@ class ConstantBias { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; }///\namespace imuBias @@ -73,9 +70,6 @@ class NavState { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -121,9 +115,6 @@ virtual class PreintegrationParams : gtsam::PreintegratedRotationParams { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -156,9 +147,6 @@ class PreintegratedImuMeasurements { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; virtual class ImuFactor: gtsam::NonlinearFactor { diff --git a/gtsam/navigation/tests/testBarometricFactor.cpp b/gtsam/navigation/tests/testBarometricFactor.cpp new file mode 100644 index 0000000000..47f4824c11 --- /dev/null +++ b/gtsam/navigation/tests/testBarometricFactor.cpp @@ -0,0 +1,129 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testBarometricFactor.cpp + * @brief Unit test for BarometricFactor + * @author Peter Milani + * @date 16 Dec, 2021 + */ + +#include +#include +#include +#include + +#include + +using namespace std::placeholders; +using namespace std; +using namespace gtsam; + +// ************************************************************************* +namespace example {} + +double metersToBaro(const double& meters) { + double temp = 15.04 - 0.00649 * meters; + return 101.29 * std::pow(((temp + 273.1) / 288.08), 5.256); +} + +// ************************************************************************* +TEST(BarometricFactor, Constructor) { + using namespace example; + + // meters to barometric. + + double baroMeasurement = metersToBaro(10.); + + // Factor + Key key(1); + Key key2(2); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(1, 0.25); + BarometricFactor factor(key, key2, baroMeasurement, model); + + // Create a linearization point at zero error + Pose3 T(Rot3::RzRyRx(0., 0., 0.), Point3(0., 0., 10.)); + double baroBias = 0.; + Vector1 zero; + zero << 0.; + EXPECT(assert_equal(zero, factor.evaluateError(T, baroBias), 1e-5)); + + // Calculate numerical derivatives + Matrix expectedH = numericalDerivative21( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + Matrix expectedH2 = numericalDerivative22( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + // Use the factor to calculate the derivative + Matrix actualH, actualH2; + factor.evaluateError(T, baroBias, actualH, actualH2); + + // Verify we get the expected error + EXPECT(assert_equal(expectedH, actualH, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2, 1e-8)); +} + +// ************************************************************************* + +//*************************************************************************** +TEST(BarometricFactor, nonZero) { + using namespace example; + + // meters to barometric. + + double baroMeasurement = metersToBaro(10.); + + // Factor + Key key(1); + Key key2(2); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(1, 0.25); + BarometricFactor factor(key, key2, baroMeasurement, model); + + Pose3 T(Rot3::RzRyRx(0.5, 1., 1.), Point3(20., 30., 1.)); + double baroBias = 5.; + + // Calculate numerical derivatives + Matrix expectedH = numericalDerivative21( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + Matrix expectedH2 = numericalDerivative22( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + // Use the factor to calculate the derivative and the error + Matrix actualH, actualH2; + Vector error = factor.evaluateError(T, baroBias, actualH, actualH2); + Vector actual = (Vector(1) << -4.0).finished(); + + // Verify we get the expected error + EXPECT(assert_equal(expectedH, actualH, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2, 1e-8)); + EXPECT(assert_equal(error, actual, 1e-8)); +} + +// ************************************************************************* +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +// ************************************************************************* diff --git a/gtsam/nonlinear/GraphvizFormatting.cpp b/gtsam/nonlinear/GraphvizFormatting.cpp new file mode 100644 index 0000000000..c37f07c8a8 --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.cpp @@ -0,0 +1,136 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.cpp + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#include +#include + +// TODO(frank): nonlinear should not depend on geometry: +#include +#include + +#include + +namespace gtsam { + +Vector2 GraphvizFormatting::findBounds(const Values& values, + const KeySet& keys) const { + Vector2 min; + min.x() = std::numeric_limits::infinity(); + min.y() = std::numeric_limits::infinity(); + for (const Key& key : keys) { + if (values.exists(key)) { + boost::optional xy = operator()(values.at(key)); + if (xy) { + if (xy->x() < min.x()) min.x() = xy->x(); + if (xy->y() < min.y()) min.y() = xy->y(); + } + } + } + return min; +} + +boost::optional GraphvizFormatting::operator()( + const Value& value) const { + Vector3 t; + if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value().translation(); + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value(); + } else { + return boost::none; + } + double x, y; + switch (paperHorizontalAxis) { + case X: + x = t.x(); + break; + case Y: + x = t.y(); + break; + case Z: + x = t.z(); + break; + case NEGX: + x = -t.x(); + break; + case NEGY: + x = -t.y(); + break; + case NEGZ: + x = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + switch (paperVerticalAxis) { + case X: + y = t.x(); + break; + case Y: + y = t.y(); + break; + case Z: + y = t.z(); + break; + case NEGX: + y = -t.x(); + break; + case NEGY: + y = -t.y(); + break; + case NEGZ: + y = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + return Vector2(x, y); +} + +// Return affinely transformed variable position if it exists. +boost::optional GraphvizFormatting::variablePos(const Values& values, + const Vector2& min, + Key key) const { + if (!values.exists(key)) return boost::none; + boost::optional xy = operator()(values.at(key)); + if (xy) { + xy->x() = scale * (xy->x() - min.x()); + xy->y() = scale * (xy->y() - min.y()); + } + return xy; +} + +// Return affinely transformed factor position if it exists. +boost::optional GraphvizFormatting::factorPos(const Vector2& min, + size_t i) const { + if (factorPositions.size() == 0) return boost::none; + auto it = factorPositions.find(i); + if (it == factorPositions.end()) return boost::none; + auto pos = it->second; + return Vector2(scale * (pos.x() - min.x()), scale * (pos.y() - min.y())); +} + +} // namespace gtsam diff --git a/gtsam/nonlinear/GraphvizFormatting.h b/gtsam/nonlinear/GraphvizFormatting.h new file mode 100644 index 0000000000..c36b09a8fc --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.h @@ -0,0 +1,69 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.h + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include + +namespace gtsam { + +class Values; +class Value; + +/** + * Formatting options and functions for saving a NonlinearFactorGraph instance + * in GraphViz format. + */ +struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { + /// World axes to be assigned to paper axes + enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; + + Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal + ///< paper axis + Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper + ///< axis + double scale; ///< Scale all positions to reduce / increase density + bool mergeSimilarFactors; ///< Merge multiple factors that have the same + ///< connectivity + + /// (optional for each factor) Manually specify factor "dot" positions: + std::map factorPositions; + + /// Default constructor sets up robot coordinates. Paper horizontal is robot + /// Y, paper vertical is robot X. Default figure size of 5x5 in. + GraphvizFormatting() + : paperHorizontalAxis(Y), + paperVerticalAxis(X), + scale(1), + mergeSimilarFactors(false) {} + + // Find bounds + Vector2 findBounds(const Values& values, const KeySet& keys) const; + + /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 + boost::optional operator()(const Value& value) const; + + /// Return affinely transformed variable position if it exists. + boost::optional variablePos(const Values& values, const Vector2& min, + Key key) const; + + /// Return affinely transformed factor position if it exists. + boost::optional factorPos(const Vector2& min, size_t i) const; +}; + +} // namespace gtsam diff --git a/gtsam/nonlinear/NonlinearFactor.h b/gtsam/nonlinear/NonlinearFactor.h index 7fafd95dfa..38d831e152 100644 --- a/gtsam/nonlinear/NonlinearFactor.h +++ b/gtsam/nonlinear/NonlinearFactor.h @@ -282,7 +282,7 @@ class GTSAM_EXPORT NoiseModelFactor: public NonlinearFactor { * which are objects in non-linear manifolds (Lie groups). */ template -class NoiseModelFactor1: public NoiseModelFactor { +class GTSAM_EXPORT NoiseModelFactor1: public NoiseModelFactor { public: @@ -366,7 +366,7 @@ class NoiseModelFactor1: public NoiseModelFactor { /** A convenient base class for creating your own NoiseModelFactor with 2 * variables. To derive from this class, implement evaluateError(). */ template -class NoiseModelFactor2: public NoiseModelFactor { +class GTSAM_EXPORT NoiseModelFactor2: public NoiseModelFactor { public: @@ -441,7 +441,7 @@ class NoiseModelFactor2: public NoiseModelFactor { /** A convenient base class for creating your own NoiseModelFactor with 3 * variables. To derive from this class, implement evaluateError(). */ template -class NoiseModelFactor3: public NoiseModelFactor { +class GTSAM_EXPORT NoiseModelFactor3: public NoiseModelFactor { public: @@ -518,7 +518,7 @@ class NoiseModelFactor3: public NoiseModelFactor { /** A convenient base class for creating your own NoiseModelFactor with 4 * variables. To derive from this class, implement evaluateError(). */ template -class NoiseModelFactor4: public NoiseModelFactor { +class GTSAM_EXPORT NoiseModelFactor4: public NoiseModelFactor { public: @@ -599,7 +599,7 @@ class NoiseModelFactor4: public NoiseModelFactor { /** A convenient base class for creating your own NoiseModelFactor with 5 * variables. To derive from this class, implement evaluateError(). */ template -class NoiseModelFactor5: public NoiseModelFactor { +class GTSAM_EXPORT NoiseModelFactor5: public NoiseModelFactor { public: @@ -684,7 +684,7 @@ class NoiseModelFactor5: public NoiseModelFactor { /** A convenient base class for creating your own NoiseModelFactor with 6 * variables. To derive from this class, implement evaluateError(). */ template -class NoiseModelFactor6: public NoiseModelFactor { +class GTSAM_EXPORT NoiseModelFactor6: public NoiseModelFactor { public: diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 8e4cf277c2..0d1ed31487 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -35,7 +35,6 @@ #include #include -#include using namespace std; @@ -91,89 +90,25 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol) } /* ************************************************************************* */ -void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, - const GraphvizFormatting& formatting, - const KeyFormatter& keyFormatter) const -{ - stm << "graph {\n"; - stm << " size=\"" << formatting.figureWidthInches << "," << - formatting.figureHeightInches << "\";\n\n"; +void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + writer.writePreamble(&os); + // Find bounds (imperative) KeySet keys = this->keys(); - - // Local utility function to extract x and y coordinates - struct { boost::optional operator()( - const Value& value, const GraphvizFormatting& graphvizFormatting) - { - Vector3 t; - if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value().translation(); - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value(); - } else { - return boost::none; - } - double x, y; - switch (graphvizFormatting.paperHorizontalAxis) { - case GraphvizFormatting::X: x = t.x(); break; - case GraphvizFormatting::Y: x = t.y(); break; - case GraphvizFormatting::Z: x = t.z(); break; - case GraphvizFormatting::NEGX: x = -t.x(); break; - case GraphvizFormatting::NEGY: x = -t.y(); break; - case GraphvizFormatting::NEGZ: x = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - switch (graphvizFormatting.paperVerticalAxis) { - case GraphvizFormatting::X: y = t.x(); break; - case GraphvizFormatting::Y: y = t.y(); break; - case GraphvizFormatting::Z: y = t.z(); break; - case GraphvizFormatting::NEGX: y = -t.x(); break; - case GraphvizFormatting::NEGY: y = -t.y(); break; - case GraphvizFormatting::NEGZ: y = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - return Point2(x,y); - }} getXY; - - // Find bounds - double minX = numeric_limits::infinity(), maxX = -numeric_limits::infinity(); - double minY = numeric_limits::infinity(), maxY = -numeric_limits::infinity(); - for (const Key& key : keys) { - if (values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) { - if(xy->x() < minX) - minX = xy->x(); - if(xy->x() > maxX) - maxX = xy->x(); - if(xy->y() < minY) - minY = xy->y(); - if(xy->y() > maxY) - maxY = xy->y(); - } - } - } + Vector2 min = writer.findBounds(values, keys); // Create nodes for each variable in the graph - for(Key key: keys){ - // Label the node with the label from the KeyFormatter - stm << " var" << key << "[label=\"" << keyFormatter(key) << "\""; - if(values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) - stm << ", pos=\"" << formatting.scale*(xy->x() - minX) << "," << formatting.scale*(xy->y() - minY) << "!\""; - } - stm << "];\n"; + for (Key key : keys) { + auto position = writer.variablePos(values, min, key); + writer.DrawVariable(key, keyFormatter, position, &os); } - stm << "\n"; + os << "\n"; - if (formatting.mergeSimilarFactors) { + if (writer.mergeSimilarFactors) { // Remove duplicate factors - std::set structure; + std::set structure; for (const sharedFactor& factor : factors_) { if (factor) { KeyVector factorKeys = factor->keys(); @@ -184,86 +119,40 @@ void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, // Create factors and variable connections size_t i = 0; - for(const KeyVector& factorKeys: structure){ - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = formatting.factorPositions.find(i); - if(pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale*(pos->second.x() - minX) << "," - << formatting.scale*(pos->second.y() - minY) << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - for(Key key: factorKeys) { - stm << " var" << key << "--" << "factor" << i << ";\n"; - } - - ++ i; + for (const KeyVector& factorKeys : structure) { + writer.processFactor(i++, factorKeys, boost::none, &os); } } else { // Create factors and variable connections - for(size_t i = 0; i < size(); ++i) { + for (size_t i = 0; i < size(); ++i) { const NonlinearFactor::shared_ptr& factor = at(i); - // If null pointer, move on to the next - if (!factor) { - continue; - } - - if (formatting.plotFactorPoints) { - const KeyVector& keys = factor->keys(); - if (formatting.binaryEdges && keys.size() == 2) { - stm << " var" << keys[0] << "--" - << "var" << keys[1] << ";\n"; - } else { - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = - formatting.factorPositions.find(i); - if (pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale * (pos->second.x() - minX) - << "," << formatting.scale * (pos->second.y() - minY) - << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - if (formatting.connectKeysToFactor && factor) { - for (Key key : *factor) { - stm << " var" << key << "--" - << "factor" << i << ";\n"; - } - } - } - } else { - Key k; - bool firstTime = true; - for (Key key : *this->at(i)) { - if (firstTime) { - k = key; - firstTime = false; - continue; - } - stm << " var" << key << "--" - << "var" << k << ";\n"; - k = key; - } + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os); } } } - stm << "}\n"; + os << "}\n"; + std::flush(os); +} + +/* ************************************************************************* */ +std::string NonlinearFactorGraph::dot(const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + std::stringstream ss; + dot(ss, values, keyFormatter, writer); + return ss.str(); } /* ************************************************************************* */ -void NonlinearFactorGraph::saveGraph( - const std::string& file, const Values& values, - const GraphvizFormatting& graphvizFormatting, - const KeyFormatter& keyFormatter) const { - std::ofstream of(file); - saveGraph(of, values, graphvizFormatting, keyFormatter); +void NonlinearFactorGraph::saveGraph(const std::string& filename, + const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + std::ofstream of(filename); + dot(of, values, keyFormatter, writer); of.close(); } diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index 61cbbafb98..2fad561be0 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -41,32 +42,6 @@ namespace gtsam { template class ExpressionFactor; - /** - * Formatting options when saving in GraphViz format using - * NonlinearFactorGraph::saveGraph. - */ - struct GTSAM_EXPORT GraphvizFormatting { - enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; ///< World axes to be assigned to paper axes - Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal paper axis - Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper axis - double figureWidthInches; ///< The figure width on paper in inches - double figureHeightInches; ///< The figure height on paper in inches - double scale; ///< Scale all positions to reduce / increase density - bool mergeSimilarFactors; ///< Merge multiple factors that have the same connectivity - bool plotFactorPoints; ///< Plots each factor as a dot between the variables - bool connectKeysToFactor; ///< Draw a line from each key within a factor to the dot of the factor - bool binaryEdges; ///< just use non-dotted edges for binary factors - std::map factorPositions; ///< (optional for each factor) Manually specify factor "dot" positions. - /// Default constructor sets up robot coordinates. Paper horizontal is robot Y, - /// paper vertical is robot X. Default figure size of 5x5 in. - GraphvizFormatting() : - paperHorizontalAxis(Y), paperVerticalAxis(X), - figureWidthInches(5), figureHeightInches(5), scale(1), - mergeSimilarFactors(false), plotFactorPoints(true), - connectKeysToFactor(true), binaryEdges(true) {} - }; - - /** * A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors, * which derive from NonlinearFactor. The values structures are typically (in SAM) more general @@ -115,21 +90,6 @@ namespace gtsam { /** Test equality */ bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; - /// Write the graph in GraphViz format for visualization - void saveGraph(std::ostream& stm, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /** - * Write the graph in GraphViz format to file for visualization. - * - * This is a wrapper friendly version since wrapped languages don't have - * access to C++ streams. - */ - void saveGraph(const std::string& file, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */ double error(const Values& values) const; @@ -246,7 +206,32 @@ namespace gtsam { emplace_shared>(key, prior, covariance); } - private: + /// @name Graph Display + /// @{ + + using FactorGraph::dot; + using FactorGraph::saveGraph; + + /// Output to graphviz format, stream version, with Values/extra options. + void dot(std::ostream& os, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& graphvizFormatting = + GraphvizFormatting()) const; + + /// Output to graphviz format string, with Values/extra options. + std::string dot(const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& graphvizFormatting = + GraphvizFormatting()) const; + + /// output to file with graphviz format, with Values/extra options. + void saveGraph(const std::string& filename, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& graphvizFormatting = + GraphvizFormatting()) const; + /// @} + + private: /** * Linearize from Scatter rather than from Ordering. Made private because @@ -275,6 +260,21 @@ namespace gtsam { Values GTSAM_DEPRECATED updateCholesky(const Values& values, boost::none_t, const Dampen& dampen = nullptr) const {return updateCholesky(values, dampen);} + + /** \deprecated */ + void GTSAM_DEPRECATED saveGraph( + std::ostream& os, const Values& values = Values(), + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + dot(os, values, keyFormatter, graphvizFormatting); + } + /** \deprecated */ + void GTSAM_DEPRECATED saveGraph( + const std::string& filename, const Values& values, + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + saveGraph(filename, values, keyFormatter, graphvizFormatting); + } #endif }; diff --git a/gtsam/nonlinear/Values-inl.h b/gtsam/nonlinear/Values-inl.h index 8ebdcab17c..dfcb7e174c 100644 --- a/gtsam/nonlinear/Values-inl.h +++ b/gtsam/nonlinear/Values-inl.h @@ -391,4 +391,10 @@ namespace gtsam { update(j, static_cast(GenericValue(val))); } + // insert_or_assign with templated value + template + void Values::insert_or_assign(Key j, const ValueType& val) { + insert_or_assign(j, static_cast(GenericValue(val))); + } + } diff --git a/gtsam/nonlinear/Values.cpp b/gtsam/nonlinear/Values.cpp index ebc9c51f67..adadc99c06 100644 --- a/gtsam/nonlinear/Values.cpp +++ b/gtsam/nonlinear/Values.cpp @@ -171,6 +171,25 @@ namespace gtsam { } } + /* ************************************************************************ */ + void Values::insert_or_assign(Key j, const Value& val) { + if (this->exists(j)) { + // If key already exists, perform an update. + this->update(j, val); + } else { + // If key does not exist, perform an insert. + this->insert(j, val); + } + } + + /* ************************************************************************ */ + void Values::insert_or_assign(const Values& values) { + for (const_iterator key_value = values.begin(); key_value != values.end(); + ++key_value) { + this->insert_or_assign(key_value->key, key_value->value); + } + } + /* ************************************************************************* */ void Values::erase(Key j) { KeyValueMap::iterator item = values_.find(j); diff --git a/gtsam/nonlinear/Values.h b/gtsam/nonlinear/Values.h index 207f355407..cfe6347b50 100644 --- a/gtsam/nonlinear/Values.h +++ b/gtsam/nonlinear/Values.h @@ -285,6 +285,19 @@ namespace gtsam { /** update the current available values without adding new ones */ void update(const Values& values); + /// If key j exists, update value, else perform an insert. + void insert_or_assign(Key j, const Value& val); + + /** + * Update a set of variables. + * If any variable key doe not exist, then perform an insert. + */ + void insert_or_assign(const Values& values); + + /// Templated version to insert_or_assign a variable with the given j. + template + void insert_or_assign(Key j, const ValueType& val); + /** Remove a variable from the config, throws KeyDoesNotExist if j is not present */ void erase(Key j); diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index 152c4b8e74..84c4939f49 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -131,9 +131,6 @@ class Ordering { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -196,10 +193,12 @@ class NonlinearFactorGraph { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - - void saveGraph(const string& s) const; + string dot( + const gtsam::Values& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); + void saveGraph(const string& s, const gtsam::Values& values, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; #include @@ -275,6 +274,7 @@ class Values { void insert(const gtsam::Values& values); void update(const gtsam::Values& values); + void insert_or_assign(const gtsam::Values& values); void erase(size_t j); void swap(gtsam::Values& values); @@ -289,9 +289,6 @@ class Values { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // New in 4.0, we have to specialize every insert/update/at to generate // wrappers Instead of the old: void insert(size_t j, const gtsam::Value& // value); void update(size_t j, const gtsam::Value& val); gtsam::Value @@ -351,6 +348,32 @@ class Values { void update(size_t j, Matrix matrix); void update(size_t j, double c); + void insert_or_assign(size_t j, const gtsam::Point2& point2); + void insert_or_assign(size_t j, const gtsam::Point3& point3); + void insert_or_assign(size_t j, const gtsam::Rot2& rot2); + void insert_or_assign(size_t j, const gtsam::Pose2& pose2); + void insert_or_assign(size_t j, const gtsam::SO3& R); + void insert_or_assign(size_t j, const gtsam::SO4& Q); + void insert_or_assign(size_t j, const gtsam::SOn& P); + void insert_or_assign(size_t j, const gtsam::Rot3& rot3); + void insert_or_assign(size_t j, const gtsam::Pose3& pose3); + void insert_or_assign(size_t j, const gtsam::Unit3& unit3); + void insert_or_assign(size_t j, const gtsam::Cal3_S2& cal3_s2); + void insert_or_assign(size_t j, const gtsam::Cal3DS2& cal3ds2); + void insert_or_assign(size_t j, const gtsam::Cal3Bundler& cal3bundler); + void insert_or_assign(size_t j, const gtsam::Cal3Fisheye& cal3fisheye); + void insert_or_assign(size_t j, const gtsam::Cal3Unified& cal3unified); + void insert_or_assign(size_t j, const gtsam::EssentialMatrix& essential_matrix); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::imuBias::ConstantBias& constant_bias); + void insert_or_assign(size_t j, const gtsam::NavState& nav_state); + void insert_or_assign(size_t j, Vector vector); + void insert_or_assign(size_t j, Matrix matrix); + void insert_or_assign(size_t j, double c); + template @@ -824,9 +853,6 @@ virtual class PriorFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include diff --git a/gtsam/geometry/tests/testUtilities.cpp b/gtsam/nonlinear/tests/testUtilities.cpp similarity index 68% rename from gtsam/geometry/tests/testUtilities.cpp rename to gtsam/nonlinear/tests/testUtilities.cpp index 25ac3acc87..55a7fdb136 100644 --- a/gtsam/geometry/tests/testUtilities.cpp +++ b/gtsam/nonlinear/tests/testUtilities.cpp @@ -21,7 +21,6 @@ #include #include #include -#include #include using namespace gtsam; @@ -55,6 +54,26 @@ TEST(Utilities, ExtractPoint3) { EXPECT_LONGS_EQUAL(2, all_points.rows()); } +/* ************************************************************************* */ +TEST(Utilities, ExtractVector) { + // Test normal case with 3 vectors and 1 non-vector (ignore non-vector) + auto values = Values(); + values.insert(X(0), (Vector(4) << 1, 2, 3, 4).finished()); + values.insert(X(2), (Vector(4) << 13, 14, 15, 16).finished()); + values.insert(X(1), (Vector(4) << 6, 7, 8, 9).finished()); + values.insert(X(3), Pose3()); + auto actual = utilities::extractVectors(values, 'x'); + auto expected = + (Matrix(3, 4) << 1, 2, 3, 4, 6, 7, 8, 9, 13, 14, 15, 16).finished(); + EXPECT(assert_equal(expected, actual)); + + // Check that mis-sized vectors fail + values.insert(X(4), (Vector(2) << 1, 2).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); + values.update(X(4), (Vector(6) << 1, 2, 3, 4, 5, 6).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); +} + /* ************************************************************************* */ int main() { srand(time(nullptr)); diff --git a/gtsam/nonlinear/tests/testValues.cpp b/gtsam/nonlinear/tests/testValues.cpp index b894f48169..bed2a8af93 100644 --- a/gtsam/nonlinear/tests/testValues.cpp +++ b/gtsam/nonlinear/tests/testValues.cpp @@ -172,6 +172,22 @@ TEST( Values, update_element ) CHECK(assert_equal((Vector)v2, cfg.at(key1))); } +TEST(Values, InsertOrAssign) { + Values values; + Key X(0); + double x = 1; + + CHECK(values.size() == 0); + // This should perform an insert. + values.insert_or_assign(X, x); + EXPECT(assert_equal(values.at(X), x)); + + // This should perform an update. + double y = 2; + values.insert_or_assign(X, y); + EXPECT(assert_equal(values.at(X), y)); +} + /* ************************************************************************* */ TEST(Values, basic_functions) { diff --git a/gtsam/nonlinear/utilities.h b/gtsam/nonlinear/utilities.h index fdc1da2c4e..d2b38d3743 100644 --- a/gtsam/nonlinear/utilities.h +++ b/gtsam/nonlinear/utilities.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -162,6 +163,34 @@ Matrix extractPose3(const Values& values) { return result; } +/// Extract all Vector values with a given symbol character into an mxn matrix, +/// where m is the number of symbols that match the character and n is the +/// dimension of the variables. If not all variables have dimension n, then a +/// runtime error will be thrown. The order of returned values are sorted by +/// the symbol. +/// For example, calling extractVector(values, 'x'), where values contains 200 +/// variables x1, x2, ..., x200 of type Vector each 5-dimensional, will return a +/// 200x5 matrix with row i containing xi. +Matrix extractVectors(const Values& values, char c) { + Values::ConstFiltered vectors = + values.filter(Symbol::ChrTest(c)); + if (vectors.size() == 0) { + return Matrix(); + } + auto dim = vectors.begin()->value.size(); + Matrix result(vectors.size(), dim); + Eigen::Index rowi = 0; + for (const auto& kv : vectors) { + if (kv.value.size() != dim) { + throw std::runtime_error( + "Tried to extract different-sized vectors into a single matrix"); + } + result.row(rowi) = kv.value; + ++rowi; + } + return result; +} + /// Perturb all Point2 values using normally distributed noise void perturbPoint2(Values& values, double sigma, int32_t seed = 42u) { noiseModel::Isotropic::shared_ptr model = diff --git a/gtsam/sfm/MFAS.h b/gtsam/sfm/MFAS.h index decfbed0f5..151b318ad8 100644 --- a/gtsam/sfm/MFAS.h +++ b/gtsam/sfm/MFAS.h @@ -48,7 +48,7 @@ namespace gtsam { unit translations in a projection direction. @addtogroup SFM */ -class MFAS { +class GTSAM_EXPORT MFAS { public: // used to represent edges between two nodes in the graph. When used in // translation averaging for global SfM diff --git a/gtsam/slam/EssentialMatrixFactor.h b/gtsam/slam/EssentialMatrixFactor.h index 787efac51e..5997ad2247 100644 --- a/gtsam/slam/EssentialMatrixFactor.h +++ b/gtsam/slam/EssentialMatrixFactor.h @@ -1,7 +1,20 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2014, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + /* - * @file EssentialMatrixFactor.cpp + * @file EssentialMatrixFactor.h * @brief EssentialMatrixFactor class * @author Frank Dellaert + * @author Ayush Baid + * @author Akshay Krishnan * @date December 17, 2013 */ diff --git a/gtsam/slam/ReadMe.md b/gtsam/slam/README.md similarity index 100% rename from gtsam/slam/ReadMe.md rename to gtsam/slam/README.md diff --git a/gtsam/slam/SmartFactorBase.h b/gtsam/slam/SmartFactorBase.h index ddf56b2891..209c1196d5 100644 --- a/gtsam/slam/SmartFactorBase.h +++ b/gtsam/slam/SmartFactorBase.h @@ -47,7 +47,7 @@ namespace gtsam { * @tparam CAMERA should behave like a PinholeCamera. */ template -class SmartFactorBase: public NonlinearFactor { +class GTSAM_EXPORT SmartFactorBase: public NonlinearFactor { private: typedef NonlinearFactor Base; diff --git a/gtsam/slam/SmartProjectionPoseFactor.h b/gtsam/slam/SmartProjectionPoseFactor.h index c7b1d54245..3cd69c46f3 100644 --- a/gtsam/slam/SmartProjectionPoseFactor.h +++ b/gtsam/slam/SmartProjectionPoseFactor.h @@ -41,11 +41,10 @@ namespace gtsam { * If the calibration should be optimized, as well, use SmartProjectionFactor instead! * @addtogroup SLAM */ -template -class SmartProjectionPoseFactor: public SmartProjectionFactor< - PinholePose > { - -private: +template +class GTSAM_EXPORT SmartProjectionPoseFactor + : public SmartProjectionFactor > { + private: typedef PinholePose Camera; typedef SmartProjectionFactor Base; typedef SmartProjectionPoseFactor This; @@ -156,7 +155,6 @@ class SmartProjectionPoseFactor: public SmartProjectionFactor< ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar & BOOST_SERIALIZATION_NVP(K_); } - }; // end of class declaration diff --git a/gtsam/slam/SmartProjectionRigFactor.h b/gtsam/slam/SmartProjectionRigFactor.h index 8d6918b3ec..149c129288 100644 --- a/gtsam/slam/SmartProjectionRigFactor.h +++ b/gtsam/slam/SmartProjectionRigFactor.h @@ -54,6 +54,8 @@ class SmartProjectionRigFactor : public SmartProjectionFactor { typedef SmartProjectionFactor Base; typedef SmartProjectionRigFactor This; typedef typename CAMERA::CalibrationType CALIBRATION; + typedef typename CAMERA::Measurement MEASUREMENT; + typedef typename CAMERA::MeasurementVector MEASUREMENTS; static const int DimPose = 6; ///< Pose3 dimension static const int ZDim = 2; ///< Measurement dimension @@ -118,7 +120,7 @@ class SmartProjectionRigFactor : public SmartProjectionFactor { * @param cameraId ID of the camera in the rig taking the measurement (default * 0) */ - void add(const Point2& measured, const Key& poseKey, + void add(const MEASUREMENT& measured, const Key& poseKey, const size_t& cameraId = 0) { // store measurement and key this->measured_.push_back(measured); @@ -144,7 +146,7 @@ class SmartProjectionRigFactor : public SmartProjectionFactor { * @param cameraIds IDs of the cameras in the rig taking each measurement * (same order as the measurements) */ - void add(const Point2Vector& measurements, const KeyVector& poseKeys, + void add(const MEASUREMENTS& measurements, const KeyVector& poseKeys, const FastVector& cameraIds = FastVector()) { if (poseKeys.size() != measurements.size() || (poseKeys.size() != cameraIds.size() && cameraIds.size() != 0)) { diff --git a/gtsam/slam/TriangulationFactor.h b/gtsam/slam/TriangulationFactor.h index f12053d29f..40e9538e25 100644 --- a/gtsam/slam/TriangulationFactor.h +++ b/gtsam/slam/TriangulationFactor.h @@ -33,18 +33,18 @@ class TriangulationFactor: public NoiseModelFactor1 { public: /// CAMERA type - typedef CAMERA Camera; + using Camera = CAMERA; protected: /// shorthand for base class type - typedef NoiseModelFactor1 Base; + using Base = NoiseModelFactor1; /// shorthand for this class - typedef TriangulationFactor This; + using This = TriangulationFactor; /// shorthand for measurement type, e.g. Point2 or StereoPoint2 - typedef typename CAMERA::Measurement Measurement; + using Measurement = typename CAMERA::Measurement; // Keep a copy of measurement and calibration for I/O const CAMERA camera_; ///< CAMERA in which this landmark was seen @@ -55,9 +55,10 @@ class TriangulationFactor: public NoiseModelFactor1 { const bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) public: + EIGEN_MAKE_ALIGNED_OPERATOR_NEW /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; + using shared_ptr = boost::shared_ptr; /// Default constructor TriangulationFactor() : @@ -129,7 +130,7 @@ class TriangulationFactor: public NoiseModelFactor1 { << std::endl; if (throwCheirality_) throw e; - return Eigen::Matrix::dimension,1>::Constant(2.0 * camera_.calibration().fx()); + return camera_.defaultErrorWhenTriangulatingBehindCamera(); } } diff --git a/gtsam/slam/slam.i b/gtsam/slam/slam.i index 60000dbab1..d276c4f2ec 100644 --- a/gtsam/slam/slam.i +++ b/gtsam/slam/slam.i @@ -21,9 +21,6 @@ virtual class BetweenFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -168,6 +165,10 @@ template virtual class PoseTranslationPrior : gtsam::NoiseModelFactor { PoseTranslationPrior(size_t key, const POSE& pose_z, const gtsam::noiseModel::Base* noiseModel); + POSE::Translation measured() const; + + // enabling serialization functionality + void serialize() const; }; typedef gtsam::PoseTranslationPrior PoseTranslationPrior2D; @@ -178,6 +179,7 @@ template virtual class PoseRotationPrior : gtsam::NoiseModelFactor { PoseRotationPrior(size_t key, const POSE& pose_z, const gtsam::noiseModel::Base* noiseModel); + POSE::Rotation measured() const; }; typedef gtsam::PoseRotationPrior PoseRotationPrior2D; @@ -188,6 +190,21 @@ virtual class EssentialMatrixFactor : gtsam::NoiseModelFactor { EssentialMatrixFactor(size_t key, const gtsam::Point2& pA, const gtsam::Point2& pB, const gtsam::noiseModel::Base* noiseModel); + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::EssentialMatrixFactor& other, double tol) const; + Vector evaluateError(const gtsam::EssentialMatrix& E) const; +}; + +#include +virtual class EssentialMatrixConstraint : gtsam::NoiseModelFactor { + EssentialMatrixConstraint(size_t key1, size_t key2, const gtsam::EssentialMatrix &measuredE, + const gtsam::noiseModel::Base *model); + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::EssentialMatrixConstraint& other, double tol) const; + Vector evaluateError(const gtsam::Pose3& p1, const gtsam::Pose3& p2) const; + const gtsam::EssentialMatrix& measured() const; }; #include @@ -211,9 +228,6 @@ class SfmTrack { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // enabling function to compare objects bool equals(const gtsam::SfmTrack& expected, double tol) const; }; @@ -230,9 +244,6 @@ class SfmData { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // enabling function to compare objects bool equals(const gtsam::SfmData& expected, double tol) const; }; diff --git a/gtsam/slam/tests/smartFactorScenarios.h b/gtsam/slam/tests/smartFactorScenarios.h index b17ffdac6c..66be08c674 100644 --- a/gtsam/slam/tests/smartFactorScenarios.h +++ b/gtsam/slam/tests/smartFactorScenarios.h @@ -17,11 +17,13 @@ */ #pragma once -#include -#include -#include -#include #include +#include +#include +#include +#include +#include + #include "../SmartProjectionRigFactor.h" using namespace std; @@ -44,7 +46,7 @@ Pose3 pose_above = level_pose * Pose3(Rot3(), Point3(0, -1, 0)); // Create a noise unit2 for the pixel error static SharedNoiseModel unit2(noiseModel::Unit::Create(2)); -static double fov = 60; // degrees +static double fov = 60; // degrees static size_t w = 640, h = 480; /* ************************************************************************* */ @@ -63,7 +65,7 @@ Camera cam2(pose_right, K2); Camera cam3(pose_above, K2); typedef GeneralSFMFactor SFMFactor; SmartProjectionParams params; -} +} // namespace vanilla /* ************************************************************************* */ // default Cal3_S2 poses @@ -78,7 +80,7 @@ Camera level_camera_right(pose_right, sharedK); Camera cam1(level_pose, sharedK); Camera cam2(pose_right, sharedK); Camera cam3(pose_above, sharedK); -} +} // namespace vanillaPose /* ************************************************************************* */ // default Cal3_S2 poses @@ -93,7 +95,7 @@ Camera level_camera_right(pose_right, sharedK2); Camera cam1(level_pose, sharedK2); Camera cam2(pose_right, sharedK2); Camera cam3(pose_above, sharedK2); -} +} // namespace vanillaPose2 /* *************************************************************************/ // Cal3Bundler cameras @@ -111,7 +113,8 @@ Camera cam1(level_pose, K); Camera cam2(pose_right, K); Camera cam3(pose_above, K); typedef GeneralSFMFactor SFMFactor; -} +} // namespace bundler + /* *************************************************************************/ // Cal3Bundler poses namespace bundlerPose { @@ -119,35 +122,50 @@ typedef PinholePose Camera; typedef CameraSet Cameras; typedef SmartProjectionPoseFactor SmartFactor; typedef SmartProjectionRigFactor SmartRigFactor; -static boost::shared_ptr sharedBundlerK( - new Cal3Bundler(500, 1e-3, 1e-3, 1000, 2000)); +static boost::shared_ptr sharedBundlerK(new Cal3Bundler(500, 1e-3, + 1e-3, 1000, + 2000)); Camera level_camera(level_pose, sharedBundlerK); Camera level_camera_right(pose_right, sharedBundlerK); Camera cam1(level_pose, sharedBundlerK); Camera cam2(pose_right, sharedBundlerK); Camera cam3(pose_above, sharedBundlerK); -} +} // namespace bundlerPose + +/* ************************************************************************* */ +// sphericalCamera +namespace sphericalCamera { +typedef SphericalCamera Camera; +typedef CameraSet Cameras; +typedef SmartProjectionRigFactor SmartFactorP; +static EmptyCal::shared_ptr emptyK(new EmptyCal()); +Camera level_camera(level_pose); +Camera level_camera_right(pose_right); +Camera cam1(level_pose); +Camera cam2(pose_right); +Camera cam3(pose_above); +} // namespace sphericalCamera /* *************************************************************************/ -template +template CAMERA perturbCameraPose(const CAMERA& camera) { - Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), - Point3(0.5, 0.1, 0.3)); + Pose3 noise_pose = + Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), Point3(0.5, 0.1, 0.3)); Pose3 cameraPose = camera.pose(); Pose3 perturbedCameraPose = cameraPose.compose(noise_pose); return CAMERA(perturbedCameraPose, camera.calibration()); } -template -void projectToMultipleCameras(const CAMERA& cam1, const CAMERA& cam2, - const CAMERA& cam3, Point3 landmark, typename CAMERA::MeasurementVector& measurements_cam) { - Point2 cam1_uv1 = cam1.project(landmark); - Point2 cam2_uv1 = cam2.project(landmark); - Point2 cam3_uv1 = cam3.project(landmark); +template +void projectToMultipleCameras( + const CAMERA& cam1, const CAMERA& cam2, const CAMERA& cam3, Point3 landmark, + typename CAMERA::MeasurementVector& measurements_cam) { + typename CAMERA::Measurement cam1_uv1 = cam1.project(landmark); + typename CAMERA::Measurement cam2_uv1 = cam2.project(landmark); + typename CAMERA::Measurement cam3_uv1 = cam3.project(landmark); measurements_cam.push_back(cam1_uv1); measurements_cam.push_back(cam2_uv1); measurements_cam.push_back(cam3_uv1); } /* ************************************************************************* */ - diff --git a/gtsam/slam/tests/testEssentialMatrixConstraint.cpp b/gtsam/slam/tests/testEssentialMatrixConstraint.cpp index 080239b350..2faac24d1b 100644 --- a/gtsam/slam/tests/testEssentialMatrixConstraint.cpp +++ b/gtsam/slam/tests/testEssentialMatrixConstraint.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file testEssentialMatrixConstraint.cpp + * @file TestEssentialMatrixConstraint.cpp * @brief Unit tests for EssentialMatrixConstraint Class * @author Frank Dellaert * @author Pablo Alcantarilla diff --git a/gtsam/slam/tests/testSmartProjectionRigFactor.cpp b/gtsam/slam/tests/testSmartProjectionRigFactor.cpp index b8150a1aa2..b4876b27ea 100644 --- a/gtsam/slam/tests/testSmartProjectionRigFactor.cpp +++ b/gtsam/slam/tests/testSmartProjectionRigFactor.cpp @@ -55,8 +55,6 @@ Key cameraId3 = 2; static Point2 measurement1(323.0, 240.0); LevenbergMarquardtParams lmParams; -// Make more verbose like so (in tests): -// params.verbosityLM = LevenbergMarquardtParams::SUMMARY; /* ************************************************************************* */ // default Cal3_S2 poses with rolling shutter effect @@ -1187,10 +1185,9 @@ TEST(SmartProjectionRigFactor, optimization_3poses_measurementsFromSamePose) { // this factor is slightly slower (but comparable) to original // SmartProjectionPoseFactor //-Total: 0 CPU (0 times, 0 wall, 0.17 children, min: 0 max: 0) -//| -SmartRigFactor LINEARIZE: 0.06 CPU -//(10000 times, 0.061226 wall, 0.06 children, min: 0 max: 0) -//| -SmartPoseFactor LINEARIZE: 0.06 CPU -//(10000 times, 0.073037 wall, 0.06 children, min: 0 max: 0) +//| -SmartRigFactor LINEARIZE: 0.05 CPU (10000 times, 0.057952 wall, 0.05 +// children, min: 0 max: 0) | -SmartPoseFactor LINEARIZE: 0.05 CPU (10000 +// times, 0.069647 wall, 0.05 children, min: 0 max: 0) /* *************************************************************************/ TEST(SmartProjectionRigFactor, timing) { using namespace vanillaRig; @@ -1249,6 +1246,355 @@ TEST(SmartProjectionRigFactor, timing) { } #endif +/* *************************************************************************/ +TEST(SmartProjectionFactorP, optimization_3poses_sphericalCamera) { + using namespace sphericalCamera; + Camera::MeasurementVector measurements_lmk1, measurements_lmk2, + measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, + measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, + measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, + measurements_lmk3); + + // create inputs + KeyVector keys; + keys.push_back(x1); + keys.push_back(x2); + keys.push_back(x3); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), emptyK)); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + SmartFactorP::shared_ptr smartFactor1( + new SmartFactorP(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, keys); + + SmartFactorP::shared_ptr smartFactor2( + new SmartFactorP(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, keys); + + SmartFactorP::shared_ptr smartFactor3( + new SmartFactorP(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, keys); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 100), + Point3(0.2, 0.2, 0.2)); // note: larger noise! + + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + + DOUBLES_EQUAL(0.94148963675515274, graph.error(values), 1e-9); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + + EXPECT(assert_equal(pose_above, result.at(x3), 1e-5)); +} + +#ifndef DISABLE_TIMING +#include +// using spherical camera is slightly slower (but comparable) to +// PinholePose +//| -SmartFactorP spherical LINEARIZE: 0.01 CPU (1000 times, 0.008178 wall, +// 0.01 children, min: 0 max: 0) | -SmartFactorP pinhole LINEARIZE: 0.01 CPU +//(1000 times, 0.005717 wall, 0.01 children, min: 0 max: 0) +/* *************************************************************************/ +TEST(SmartProjectionFactorP, timing_sphericalCamera) { + // create common data + Rot3 R = Rot3::identity(); + Pose3 pose1 = Pose3(R, Point3(0, 0, 0)); + Pose3 pose2 = Pose3(R, Point3(1, 0, 0)); + Pose3 body_P_sensorId = Pose3::identity(); + Point3 landmark1(0, 0, 10); + + // create spherical data + EmptyCal::shared_ptr emptyK; + SphericalCamera cam1_sphere(pose1, emptyK), cam2_sphere(pose2, emptyK); + // Project 2 landmarks into 2 cameras + std::vector measurements_lmk1_sphere; + measurements_lmk1_sphere.push_back(cam1_sphere.project(landmark1)); + measurements_lmk1_sphere.push_back(cam2_sphere.project(landmark1)); + + // create Cal3_S2 data + static Cal3_S2::shared_ptr sharedKSimple(new Cal3_S2(100, 100, 0, 0, 0)); + PinholePose cam1(pose1, sharedKSimple), cam2(pose2, sharedKSimple); + // Project 2 landmarks into 2 cameras + std::vector measurements_lmk1; + measurements_lmk1.push_back(cam1.project(landmark1)); + measurements_lmk1.push_back(cam2.project(landmark1)); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + + size_t nrTests = 1000; + + for (size_t i = 0; i < nrTests; i++) { + boost::shared_ptr> cameraRig( + new CameraSet()); // single camera in the rig + cameraRig->push_back(SphericalCamera(body_P_sensorId, emptyK)); + + SmartProjectionRigFactor::shared_ptr smartFactorP( + new SmartProjectionRigFactor(model, cameraRig, + params)); + smartFactorP->add(measurements_lmk1_sphere[0], x1); + smartFactorP->add(measurements_lmk1_sphere[1], x1); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartFactorP_spherical_LINEARIZE); + smartFactorP->linearize(values); + gttoc_(SmartFactorP_spherical_LINEARIZE); + } + + for (size_t i = 0; i < nrTests; i++) { + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(body_P_sensorId, sharedKSimple)); + + SmartProjectionRigFactor>::shared_ptr smartFactorP2( + new SmartProjectionRigFactor>(model, cameraRig, + params)); + smartFactorP2->add(measurements_lmk1[0], x1); + smartFactorP2->add(measurements_lmk1[1], x1); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartFactorP_pinhole_LINEARIZE); + smartFactorP2->linearize(values); + gttoc_(SmartFactorP_pinhole_LINEARIZE); + } + tictoc_print_(); +} +#endif + +/* *************************************************************************/ +TEST(SmartProjectionFactorP, 2poses_rankTol) { + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, -0.1, 0.0)); // 10cm to the right of poseA + Point3 landmarkL = Point3(5.0, 0.0, 0.0); // 5m in front of poseA + + // triangulate from a stereo with 10cm baseline, assuming standard calibration + { // default rankTol = 1 gives a valid point (compare with calibrated and + // spherical cameras below) + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + + Camera cam1(poseA, sharedK); + Camera cam2(poseB, sharedK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(1); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), sharedK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } + // triangulate from a stereo with 10cm baseline, assuming canonical + // calibration + { // default rankTol = 1 or 0.1 gives a degenerate point, which is + // undesirable for a point 5m away and 10cm baseline + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + static Cal3_S2::shared_ptr canonicalK( + new Cal3_S2(1.0, 1.0, 0.0, 0.0, 0.0)); // canonical camera + + Camera cam1(poseA, canonicalK); + Camera cam2(poseB, canonicalK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), canonicalK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.degenerate()); // valid triangulation + } + // triangulate from a stereo with 10cm baseline, assuming canonical + // calibration + { // smaller rankTol = 0.01 gives a valid point (compare with calibrated and + // spherical cameras below) + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + static Cal3_S2::shared_ptr canonicalK( + new Cal3_S2(1.0, 1.0, 0.0, 0.0, 0.0)); // canonical camera + + Camera cam1(poseA, canonicalK); + Camera cam2(poseB, canonicalK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.01); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), canonicalK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } +} + +/* *************************************************************************/ +TEST(SmartProjectionFactorP, 2poses_sphericalCamera_rankTol) { + typedef SphericalCamera Camera; + typedef SmartProjectionRigFactor SmartRigFactor; + EmptyCal::shared_ptr emptyK(new EmptyCal()); + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, -0.1, 0.0)); // 10cm to the right of poseA + Point3 landmarkL = Point3(5.0, 0.0, 0.0); // 5m in front of poseA + + Camera cam1(poseA); + Camera cam2(poseB); + + boost::shared_ptr> cameraRig( + new CameraSet()); // single camera in the rig + cameraRig->push_back(SphericalCamera(Pose3::identity(), emptyK)); + + // TRIANGULATION TEST WITH DEFAULT RANK TOL + { // rankTol = 1 or 0.1 gives a degenerate point, which is undesirable for a + // point 5m away and 10cm baseline + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.degenerate()); // not enough parallax + } + // SAME TEST WITH SMALLER RANK TOL + { // rankTol = 0.01 gives a valid point + // By playing with this test, we can show we can triangulate also with a + // baseline of 5cm (even for points far away, >100m), but the test fails + // when the baseline becomes 1cm. This suggests using rankTol = 0.01 and + // setting a reasonable max landmark distance to obtain best results. + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.01); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index d6e1c64530..bff524bc2e 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double AllDiff::operator()(const Values& values) const { +double AllDiff::operator()(const DiscreteValues& values) const { std::set taken; // record values taken by keys for (Key dkey : keys_) { size_t value = values.at(dkey); // get the value for that key @@ -57,21 +57,25 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool AllDiff::ensureArcConsistency(size_t j, - std::vector& domains) const { +bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { + Domain& Dj = domains->at(j); + // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. + // a value in domains->at(j) that does not occur in any other connected domain. // If found, we make this a singleton... // TODO: make a new constraint where this really is true - Domain& Dj = domains[j]; - if (Dj.checkAllDiff(keys_, domains)) return true; + boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); + if (maybeChanged) { + Dj = *maybeChanged; + return true; + } - // Check all other domains for singletons and erase corresponding values + // Check all other domains for singletons and erase corresponding values. // This is the same as arc-consistency on the equivalent binary constraints bool changed = false; for (Key k : keys_) if (k != j) { - const Domain& Dk = domains[k]; + const Domain& Dk = domains->at(k); if (Dk.isSingleton()) { // check if singleton size_t value = Dk.firstValue(); if (Dj.contains(value)) { @@ -84,7 +88,7 @@ bool AllDiff::ensureArcConsistency(size_t j, } /* ************************************************************************* */ -Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { +Constraint::shared_ptr AllDiff::partiallyApply(const DiscreteValues& values) const { DiscreteKeys newKeys; // loop over keys and add them only if they do not appear in values for (Key k : keys_) @@ -96,10 +100,10 @@ Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { - DiscreteFactor::Values known; + const Domains& domains) const { + DiscreteValues known; for (Key k : keys_) { - const Domain& Dk = domains[k]; + const Domain& Dk = domains.at(k); if (Dk.isSingleton()) known[k] = Dk.firstValue(); } return partiallyApply(known); diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index b0fd1d631e..9496fc1a63 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -13,11 +13,8 @@ namespace gtsam { /** - * General AllDiff constraint - * Returns 1 if values for all keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Key and an Key. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. + * General AllDiff constraint. + * Returns 1 if values for all keys are different, 0 otherwise. */ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { std::map cardinalities_; @@ -28,7 +25,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } public: - /// Constructor + /// Construct from keys. AllDiff(const DiscreteKeys& dkeys); // print @@ -48,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } /// Calculate value = expensive ! - double operator()(const Values& values) const override; + double operator()(const DiscreteValues& values) const override; /// Convert into a decisiontree, can be *very* expensive ! DecisionTreeFactor toDecisionTreeFactor() const override; @@ -57,21 +54,19 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override; + Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override; + const Domains&) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index d8e1a590aa..b207acb9d8 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -15,10 +15,7 @@ namespace gtsam { /** * Binary AllDiff constraint - * Returns 1 if values for two keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Index and an Index. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. + * Returns 1 if values for two keys are different, 0 otherwise. */ class BinaryAllDiff : public Constraint { size_t cardinality0_, cardinality1_; /// cardinality @@ -50,7 +47,7 @@ class BinaryAllDiff : public Constraint { } /// Calculate value - double operator()(const Values& values) const override { + double operator()(const DiscreteValues& values) const override { return (double)(values.at(keys_[0]) != values.at(keys_[1])); } @@ -73,25 +70,25 @@ class BinaryAllDiff : public Constraint { } /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override { - // throw std::runtime_error( - // "BinaryAllDiff::ensureArcConsistency not implemented"); + bool ensureArcConsistency(Key j, Domains* domains) const override { + throw std::runtime_error( + "BinaryAllDiff::ensureArcConsistency not implemented"); return false; } /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override { + Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override { throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); } /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override { + const Domains&) const override { throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); } }; diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index b1d70dc6e6..283c992f13 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -14,95 +14,86 @@ using namespace std; namespace gtsam { /// Find the best total assignment - can be expensive -CSP::sharedValues CSP::optimalAssignment() const { +DiscreteValues CSP::optimalAssignment() const { DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(); - sharedValues mpe = chordal->optimize(); - return mpe; + return chordal->optimize(); } /// Find the best total assignment - can be expensive -CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const { +DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const { DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering); - sharedValues mpe = chordal->optimize(); - return mpe; + return chordal->optimize(); } -void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, - bool print) const { +bool CSP::runArcConsistency(const VariableIndex& index, + Domains* domains) const { + bool changed = false; + + // iterate over all variables in the index + for (auto entry : index) { + // Get the variable's key and associated factors: + const Key key = entry.first; + const FactorIndices& factors = entry.second; + + // If this domain is already a singleton, we do nothing. + if (domains->at(key).isSingleton()) continue; + + // Otherwise, loop over all factors/constraints for variable with given key. + for (size_t f : factors) { + // If this factor is a constraint, call its ensureArcConsistency method: + auto constraint = boost::dynamic_pointer_cast((*this)[f]); + if (constraint) { + changed = constraint->ensureArcConsistency(key, domains) || changed; + } + } + } + return changed; +} + +// TODO(dellaert): This is AC1, which is inefficient as any change will cause +// the algorithm to revisit *all* variables again. Implement AC3. +Domains CSP::runArcConsistency(size_t cardinality, size_t maxIterations) const { // Create VariableIndex VariableIndex index(*this); - // index.print(); - - size_t n = index.size(); // Initialize domains - std::vector domains; - for (size_t j = 0; j < n; j++) - domains.push_back(Domain(DiscreteKey(j, cardinality))); - - // Create array of flags indicating a domain changed or not - std::vector changed(n); + Domains domains; + for (auto entry : index) { + const Key key = entry.first; + domains.emplace(key, DiscreteKey(key, cardinality)); + } - // iterate nrIterations over entire grid - for (size_t it = 0; it < nrIterations; it++) { - bool anyChange = false; - // iterate over all cells - for (size_t v = 0; v < n; v++) { - // keep track of which domains changed - changed[v] = false; - // loop over all factors/constraints for variable v - const FactorIndices& factors = index[v]; - for (size_t f : factors) { - // if not already a singleton - if (!domains[v].isSingleton()) { - // get the constraint and call its ensureArcConsistency method - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast((*this)[f]); - if (!constraint) - throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - changed[v] = - constraint->ensureArcConsistency(v, domains) || changed[v]; - } - } // f - if (changed[v]) anyChange = true; - } // v - if (!anyChange) break; - // TODO: Sudoku specific hack - if (print) { - if (cardinality == 9 && n == 81) { - for (size_t i = 0, v = 0; i < (size_t)std::sqrt((double)n); i++) { - for (size_t j = 0; j < (size_t)std::sqrt((double)n); j++, v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // i - cout << endl; - } // j - } else { - for (size_t v = 0; v < n; v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // v - } - cout << endl; - } // print - } // it + // Iterate until convergence or not a single domain changed. + for (size_t it = 0; it < maxIterations; it++) { + bool changed = runArcConsistency(index, &domains); + if (!changed) break; + } + return domains; +} -#ifndef INPROGRESS - // Now create new problem with all singleton variables removed - // We do this by adding simplifying all factors using parial application +CSP CSP::partiallyApply(const Domains& domains) const { + // Create new problem with all singleton variables removed + // We do this by adding simplifying all factors using partial application. // TODO: create a new ordering as we go, to ensure a connected graph // KeyOrdering ordering; // vector dkeys; + CSP new_csp; + + // Add tightened domains as new factors: + for (auto key_domain : domains) { + new_csp.emplace_shared(key_domain.second); + } + + // Reduce all existing factors: for (const DiscreteFactor::shared_ptr& f : factors_) { - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast(f); + auto constraint = boost::dynamic_pointer_cast(f); if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); Constraint::shared_ptr reduced = constraint->partiallyApply(domains); - if (print) reduced->print(); + if (reduced->size() > 1) { + new_csp.push_back(reduced); + } } -#endif + return new_csp; } } // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index 544cdf0c95..e7fb301156 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -20,33 +20,20 @@ namespace gtsam { */ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { public: - /** A map from keys to values */ - typedef KeyVector Indices; - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; - - public: - // /// Constructor - // CSP() { - // } + using Values = DiscreteValues; ///< backwards compatibility /// Add a unary constraint, allowing only a single value void addSingleValue(const DiscreteKey& dkey, size_t value) { - boost::shared_ptr factor(new SingleValue(dkey, value)); - push_back(factor); + emplace_shared(dkey, value); } /// Add a binary AllDiff constraint void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { - boost::shared_ptr factor(new BinaryAllDiff(key1, key2)); - push_back(factor); + emplace_shared(key1, key2); } /// Add a general AllDiff constraint - void addAllDiff(const DiscreteKeys& dkeys) { - boost::shared_ptr factor(new AllDiff(dkeys)); - push_back(factor); - } + void addAllDiff(const DiscreteKeys& dkeys) { emplace_shared(dkeys); } // /** return product of all factors as a single factor */ // DecisionTreeFactor product() const { @@ -56,11 +43,11 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // return result; // } - /// Find the best total assignment - can be expensive - sharedValues optimalAssignment() const; + /// Find the best total assignment - can be expensive. + DiscreteValues optimalAssignment() const; - /// Find the best total assignment - can be expensive - sharedValues optimalAssignment(const Ordering& ordering) const; + /// Find the best total assignment, with given ordering - can be expensive. + DiscreteValues optimalAssignment(const Ordering& ordering) const; // /* // * Perform loopy belief propagation @@ -72,16 +59,24 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // deep. // * It will be very expensive to exclude values that way. // */ - // void applyBeliefPropagation(size_t nrIterations = 10) const; + // void applyBeliefPropagation(size_t maxIterations = 10) const; /* * Apply arc-consistency ~ Approximate loopy belief propagation * We need to give the domains to a constraint, and it returns * a domain whose values don't conflict in the arc-consistency way. - * TODO: should get cardinality from Indices + * TODO: should get cardinality from DiscreteKeys + */ + Domains runArcConsistency(size_t cardinality, + size_t maxIterations = 10) const; + + /// Run arc consistency for all variables, return true if any domain changed. + bool runArcConsistency(const VariableIndex& index, Domains* domains) const; + + /* + * Create a new CSP, applying the given Domain constraints. */ - void runArcConsistency(size_t cardinality, size_t nrIterations = 10, - bool print = false) const; + CSP partiallyApply(const Domains& domains) const; }; // CSP } // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index b8baccff98..5c21028a0c 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -18,33 +18,37 @@ #pragma once #include +#include #include #include +#include +#include namespace gtsam { class Domain; +using Domains = std::map; /** - * Base class for discrete probabilistic factors - * The most general one is the derived DecisionTreeFactor + * Base class for constraint factors + * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. */ -class Constraint : public DiscreteFactor { +class GTSAM_EXPORT Constraint : public DiscreteFactor { public: typedef boost::shared_ptr shared_ptr; protected: - /// Construct n-way factor - Constraint(const KeyVector& js) : DiscreteFactor(js) {} - - /// Construct unary factor + /// Construct unary constraint factor. Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} - /// Construct binary factor + /// Construct binary constraint factor. Constraint(Key j1, Key j2) : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} + /// Construct n-way constraint factor. + Constraint(const KeyVector& js) : DiscreteFactor(js) {} + /// construct from container template Constraint(KeyIterator beginKey, KeyIterator endKey) @@ -65,18 +69,28 @@ class Constraint : public DiscreteFactor { /// @{ /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - virtual bool ensureArcConsistency(size_t j, - std::vector& domains) const = 0; + virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; /// Partially apply known values - virtual shared_ptr partiallyApply(const Values&) const = 0; + virtual shared_ptr partiallyApply(const DiscreteValues&) const = 0; /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; + virtual shared_ptr partiallyApply(const Domains&) const = 0; + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { + return (boost::format("`Constraint` on %1% variables\n") % (size())).str(); + } + /// @} }; // DiscreteFactor diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index a81b1d1ad5..7acc10cb4a 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -10,29 +10,35 @@ #include #include - +#include namespace gtsam { using namespace std; /* ************************************************************************* */ void Domain::print(const string& s, const KeyFormatter& formatter) const { - // cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << - // formatter(keys_[0]) << ") with values"; - // for (size_t v: values_) cout << " " << v; - // cout << endl; - for (size_t v : values_) cout << v; + cout << s << ": Domain on " << formatter(key()) << " (j=" << formatter(key()) + << ") with values"; + for (size_t v : values_) cout << " " << v; + cout << endl; +} + +/* ************************************************************************* */ +string Domain::base1Str() const { + stringstream ss; + for (size_t v : values_) ss << v + 1; + return ss.str(); } /* ************************************************************************* */ -double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); +double Domain::operator()(const DiscreteValues& values) const { + return contains(values.at(key())); } /* ************************************************************************* */ DecisionTreeFactor Domain::toDecisionTreeFactor() const { DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); + keys += DiscreteKey(key(), cardinality_); vector table; for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); DecisionTreeFactor converted(keys, table); @@ -46,9 +52,9 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool Domain::ensureArcConsistency(size_t j, vector& domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); - Domain& D = domains[j]; +bool Domain::ensureArcConsistency(Key j, Domains* domains) const { + if (j != key()) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains->at(j); for (size_t value : values_) if (!D.contains(value)) throw runtime_error("Unsatisfiable"); D = *this; @@ -56,34 +62,33 @@ bool Domain::ensureArcConsistency(size_t j, vector& domains) const { } /* ************************************************************************* */ -bool Domain::checkAllDiff(const KeyVector keys, vector& domains) { - Key j = keys_[0]; +boost::optional Domain::checkAllDiff(const KeyVector keys, + const Domains& domains) const { + Key j = key(); // for all values in this domain - for (size_t value : values_) { + for (const size_t value : values_) { // for all connected domains - for (Key k : keys) + for (const Key k : keys) // if any domain contains the value we cannot make this domain singleton - if (k != j && domains[k].contains(value)) goto found; - values_.clear(); - values_.insert(value); - return true; // we changed it + if (k != j && domains.at(k).contains(value)) goto found; + // Otherwise: return a singleton: + return Domain(this->discreteKey(), value); found:; } - return false; // we did not change it + return boost::none; // we did not change it } /* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); +Constraint::shared_ptr Domain::partiallyApply(const DiscreteValues& values) const { + DiscreteValues::const_iterator it = values.find(key()); if (it != values.end() && !contains(it->second)) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(*this); } /* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; +Constraint::shared_ptr Domain::partiallyApply(const Domains& domains) const { + const Domain& Dk = domains.at(key()); if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(Dk); diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 15828b6533..1047101c52 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -13,7 +13,8 @@ namespace gtsam { /** - * Domain restriction constraint + * The Domain class represents a constraint that restricts the possible values a + * particular variable, with given key, can take on. */ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { size_t cardinality_; /// Cardinality @@ -35,14 +36,16 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { values_.insert(v); } - /// Constructor - Domain(const Domain& other) - : Constraint(other.keys_[0]), values_(other.values_) {} + /// The one key + Key key() const { return keys_[0]; } - /// insert a value, non const :-( + // The associated discrete key + DiscreteKey discreteKey() const { return DiscreteKey(key(), cardinality_); } + + /// Insert a value, non const :-( void insert(size_t value) { values_.insert(value); } - /// erase a value, non const :-( + /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } size_t nrValues() const { return values_.size(); } @@ -65,10 +68,15 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { } } + // Return concise string representation, mostly to debug arc consistency. + // Converts from base 0 to base1. + std::string base1Str() const; + + // Check whether domain cotains a specific value. bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value - double operator()(const Values& values) const override; + double operator()(const DiscreteValues& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; @@ -77,27 +85,29 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /** - * Check for a value in domain that does not occur in any other connected - * domain. If found, we make this a singleton... Called in - * AllDiff::ensureArcConsistency - * @param keys connected domains through alldiff + * Check for a value in domain that does not occur in any other connected + * domain. If found, return a a new singleton domain... + * Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + * @param keys other domains */ - bool checkAllDiff(const KeyVector keys, std::vector& domains); + boost::optional checkAllDiff(const KeyVector keys, + const Domains& domains) const; /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; + Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override; /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + Constraint::shared_ptr partiallyApply(const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index 415f92e626..e34613c3b3 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -133,10 +133,10 @@ void Scheduler::addStudentSpecificConstraints(size_t i, Potentials::ADT p(dummy & areaKey, available_); // available_ is Doodle string Potentials::ADT q = p.choose(dummyIndex, *slot); - DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q)); - CSP::push_back(f); + CSP::add(areaKey, q); } else { - CSP::add(s.key_, areaKey, available_); // available_ is Doodle string + DiscreteKeys keys {s.key_, areaKey}; + CSP::add(keys, available_); // available_ is Doodle string } } @@ -202,16 +202,16 @@ void Scheduler::print(const string& s, const KeyFormatter& formatter) const { } // print /** Print readable form of assignment */ -void Scheduler::printAssignment(sharedValues assignment) const { +void Scheduler::printAssignment(const DiscreteValues& assignment) const { // Not intended to be general! Assumes very particular ordering ! cout << endl; for (size_t s = 0; s < nrStudents(); s++) { Key j = 3 * maxNrStudents_ + s; - size_t slot = assignment->at(j); + size_t slot = assignment.at(j); cout << studentName(s) << " slot: " << slotName_[slot] << endl; Key base = 3 * s; for (size_t area = 0; area < 3; area++) { - size_t faculty = assignment->at(base + area); + size_t faculty = assignment.at(base + area); cout << setw(12) << studentArea(s, area) << ": " << facultyName_[faculty] << endl; } @@ -220,8 +220,8 @@ void Scheduler::printAssignment(sharedValues assignment) const { } /** Special print for single-student case */ -void Scheduler::printSpecial(sharedValues assignment) const { - Values::const_iterator it = assignment->begin(); +void Scheduler::printSpecial(const DiscreteValues& assignment) const { + DiscreteValues::const_iterator it = assignment.begin(); for (size_t area = 0; area < 3; area++, it++) { size_t f = it->second; cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl; @@ -230,12 +230,12 @@ void Scheduler::printSpecial(sharedValues assignment) const { } /** Accumulate faculty stats */ -void Scheduler::accumulateStats(sharedValues assignment, +void Scheduler::accumulateStats(const DiscreteValues& assignment, vector& stats) const { for (size_t s = 0; s < nrStudents(); s++) { Key base = 3 * s; for (size_t area = 0; area < 3; area++) { - size_t f = assignment->at(base + area); + size_t f = assignment.at(base + area); assert(f < stats.size()); stats[f]++; } // area @@ -256,7 +256,7 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { } /** Find the best total assignment - can be expensive */ -Scheduler::sharedValues Scheduler::optimalAssignment() const { +DiscreteValues Scheduler::optimalAssignment() const { DiscreteBayesNet::shared_ptr chordal = eliminate(); if (ISDEBUG("Scheduler::optimalAssignment")) { @@ -267,22 +267,21 @@ Scheduler::sharedValues Scheduler::optimalAssignment() const { } gttic(my_optimize); - sharedValues mpe = chordal->optimize(); + DiscreteValues mpe = chordal->optimize(); gttoc(my_optimize); return mpe; } /** find the assignment of students to slots with most possible committees */ -Scheduler::sharedValues Scheduler::bestSchedule() const { - sharedValues best; +DiscreteValues Scheduler::bestSchedule() const { + DiscreteValues best; throw runtime_error("bestSchedule not implemented"); return best; } /** find the corresponding most desirable committee assignment */ -Scheduler::sharedValues Scheduler::bestAssignment( - sharedValues bestSchedule) const { - sharedValues best; +DiscreteValues Scheduler::bestAssignment(const DiscreteValues& bestSchedule) const { + DiscreteValues best; throw runtime_error("bestAssignment not implemented"); return best; } diff --git a/gtsam_unstable/discrete/Scheduler.h b/gtsam_unstable/discrete/Scheduler.h index faf131f5c7..7559cdea6b 100644 --- a/gtsam_unstable/discrete/Scheduler.h +++ b/gtsam_unstable/discrete/Scheduler.h @@ -134,26 +134,26 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { const KeyFormatter& formatter = DefaultKeyFormatter) const override; /** Print readable form of assignment */ - void printAssignment(sharedValues assignment) const; + void printAssignment(const DiscreteValues& assignment) const; /** Special print for single-student case */ - void printSpecial(sharedValues assignment) const; + void printSpecial(const DiscreteValues& assignment) const; /** Accumulate faculty stats */ - void accumulateStats(sharedValues assignment, + void accumulateStats(const DiscreteValues& assignment, std::vector& stats) const; /** Eliminate, return a Bayes net */ DiscreteBayesNet::shared_ptr eliminate() const; /** Find the best total assignment - can be expensive */ - sharedValues optimalAssignment() const; + DiscreteValues optimalAssignment() const; /** find the assignment of students to slots with most possible committees */ - sharedValues bestSchedule() const; + DiscreteValues bestSchedule() const; /** find the corresponding most desirable committee assignment */ - sharedValues bestAssignment(sharedValues bestSchedule) const; + DiscreteValues bestAssignment(const DiscreteValues& bestSchedule) const; }; // Scheduler diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 105887dc9e..6dd81a7dc6 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -23,7 +23,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double SingleValue::operator()(const Values& values) const { +double SingleValue::operator()(const DiscreteValues& values) const { return (double)(values.at(keys_[0]) == value_); } @@ -44,11 +44,10 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool SingleValue::ensureArcConsistency(size_t j, - vector& domains) const { +bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { if (j != keys_[0]) throw invalid_argument("SingleValue check on wrong domain"); - Domain& D = domains[j]; + Domain& D = domains->at(j); if (D.isSingleton()) { if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); return false; @@ -58,8 +57,8 @@ bool SingleValue::ensureArcConsistency(size_t j, } /* ************************************************************************* */ -Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); +Constraint::shared_ptr SingleValue::partiallyApply(const DiscreteValues& values) const { + DiscreteValues::const_iterator it = values.find(keys_[0]); if (it != values.end() && it->second != value_) throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); return boost::make_shared(keys_[0], cardinality_, value_); @@ -67,8 +66,8 @@ Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; + const Domains& domains) const { + const Domain& Dk = domains.at(keys_[0]); if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); return boost::make_shared(discreteKey(), value_); diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index a2aec338c6..3b2d6e80bb 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -13,14 +13,12 @@ namespace gtsam { /** - * SingleValue constraint + * SingleValue constraint: ensures a variable takes on a certain value. + * This could of course also be implemented by changing its `Domain`. */ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { - /// Number of values - size_t cardinality_; - - /// allowed value - size_t value_; + size_t cardinality_; /// < Number of values + size_t value_; ///< allowed value DiscreteKey discreteKey() const { return DiscreteKey(keys_[0], cardinality_); @@ -29,11 +27,11 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { public: typedef boost::shared_ptr shared_ptr; - /// Constructor + /// Construct from key, cardinality, and given value. SingleValue(Key key, size_t n, size_t value) : Constraint(key), cardinality_(n), value_(value) {} - /// Constructor + /// Construct from DiscreteKey and given value. SingleValue(const DiscreteKey& dkey, size_t value) : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} @@ -52,7 +50,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Calculate value - double operator()(const Values& values) const override; + double operator()(const DiscreteValues& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; @@ -61,19 +59,19 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency + * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; + Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index e9f63b2d8c..3460664db7 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -122,7 +122,7 @@ void runLargeExample() { // SETDEBUG("timing-verbose", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true); gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimalAssignment(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -165,7 +165,7 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; + DiscreteValues values; size_t bestSlot = root->solve(values); // get corresponding count @@ -225,7 +225,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 500; n++) { vector stats(19, 0); - vector samples; + vector samples; for (size_t i = 0; i < 7; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); @@ -319,7 +319,7 @@ void accomodateStudent() { // GTSAM_PRINT(*chordal); // solve root node only - Scheduler::Values values; + DiscreteValues values; size_t bestSlot = root->solve(values); // get corresponding count @@ -331,7 +331,7 @@ void accomodateStudent() { // sample schedules for (size_t n = 0; n < 10; n++) { - Scheduler::sharedValues sample0 = chordal->sample(); + auto sample0 = chordal->sample(); scheduler.printAssignment(sample0); } } diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index 1fc4a1459b..19694c31ec 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -129,7 +129,7 @@ void runLargeExample() { tictoc_finishedIteration(); tictoc_print(); for (size_t i=0;i<100;i++) { - DiscreteFactor::sharedValues assignment = chordal->sample(); + auto assignment = chordal->sample(); vector stats(scheduler.nrFaculty()); scheduler.accumulateStats(assignment, stats); size_t max = *max_element(stats.begin(), stats.end()); @@ -143,7 +143,7 @@ void runLargeExample() { } #else gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimalAssignment(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -190,7 +190,7 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; + DiscreteValues values; size_t bestSlot = root->solve(values); // get corresponding count @@ -234,7 +234,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 500; n++) { vector stats(19, 0); - vector samples; + vector samples; for (size_t i = 0; i < NRSTUDENTS; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp index 95b64f2897..4b96b1eeba 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp @@ -153,7 +153,7 @@ void runLargeExample() { tictoc_finishedIteration(); tictoc_print(); for (size_t i=0;i<100;i++) { - DiscreteFactor::sharedValues assignment = sample(*chordal); + auto assignment = sample(*chordal); vector stats(scheduler.nrFaculty()); scheduler.accumulateStats(assignment, stats); size_t max = *max_element(stats.begin(), stats.end()); @@ -167,7 +167,7 @@ void runLargeExample() { } #else gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimalAssignment(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -212,7 +212,7 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; + DiscreteValues values; size_t bestSlot = root->solve(values); // get corresponding count @@ -259,7 +259,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 10000; n++) { vector stats(nrFaculty, 0); - vector samples; + vector samples; for (size_t i = 0; i < NRSTUDENTS; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 1552fcbf14..88defd9860 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -19,12 +19,34 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE(BinaryAllDif, allInOne) { - // Create keys and ordering +TEST(CSP, SingleValue) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check that a single value is equal to a decision stump with only one "1": + SingleValue singleValue(AZ, 2); + DecisionTreeFactor f1(AZ, "0 0 1"); + EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); + + // Create domains + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // Ensure arc-consistency: just wipes out values in AZ domain: + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + LONGS_EQUAL(3, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(3, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, BinaryAllDif) { + // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each: size_t nrColors = 2; - // DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", - // nrColors); - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Check construction and conversion BinaryAllDiff c1(ID, UT); @@ -36,16 +58,53 @@ TEST_UNSAFE(BinaryAllDif, allInOne) { DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); + // Check multiplication of factors with constraint: DecisionTreeFactor f3 = f1 * f2; EXPECT(assert_equal(f3, c1 * f2)); EXPECT(assert_equal(f3, c2 * f1)); } /* ************************************************************************* */ -TEST_UNSAFE(CSP, allInOne) { - // Create keys and ordering +TEST(CSP, AllDiff) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check construction and conversion + vector dkeys{ID, UT, AZ}; + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); + // GTSAM_PRINT(actual); + actual.dot("actual"); + DecisionTreeFactor f2( + ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2, actual)); + + // Create domains. + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // First constrict AZ domain: + SingleValue singleValue(AZ, 2); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + + // Arc-consistency + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, allInOne) { + // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each: size_t nrColors = 2; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Create the CSP CSP csp; @@ -53,14 +112,14 @@ TEST_UNSAFE(CSP, allInOne) { csp.addAllDiff(UT, AZ); // Check an invalid combination, with ID==UT==AZ all same color - DiscreteFactor::Values invalid; + DiscreteValues invalid; invalid[ID.first] = 0; invalid[UT.first] = 0; invalid[AZ.first] = 0; EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // Check a valid combination - DiscreteFactor::Values valid; + DiscreteValues valid; valid[ID.first] = 0; valid[UT.first] = 1; valid[AZ.first] = 0; @@ -73,23 +132,20 @@ TEST_UNSAFE(CSP, allInOne) { EXPECT(assert_equal(expectedProduct, product)); // Solve - CSP::sharedValues mpe = csp.optimalAssignment(); - CSP::Values expected; + auto mpe = csp.optimalAssignment(); + DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); - EXPECT(assert_equal(expected, *mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + EXPECT(assert_equal(expected, mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); } /* ************************************************************************* */ -TEST_UNSAFE(CSP, WesternUS) { - // Create keys +TEST(CSP, WesternUS) { + // Create keys for all states in Western US, with 4 color possibilities. size_t nrColors = 4; - DiscreteKey - // Create ordering according to example in ND-CSP.lyx - WA(0, nrColors), - OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors), - UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors), - CO(7, nrColors), NM(6, nrColors); + DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors), + NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); // Create the CSP CSP csp; @@ -116,13 +172,14 @@ TEST_UNSAFE(CSP, WesternUS) { csp.addAllDiff(WY, CO); csp.addAllDiff(CO, NM); - // Solve + // Create ordering according to example in ND-CSP.lyx Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), Key(8), Key(9), Key(10); - CSP::sharedValues mpe = csp.optimalAssignment(ordering); - // GTSAM_PRINT(*mpe); - CSP::Values expected; + // Solve using that ordering: + auto mpe = csp.optimalAssignment(ordering); + // GTSAM_PRINT(mpe); + DiscreteValues expected; insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)( MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)( UT.first, 1)(AZ.first, 0); @@ -130,8 +187,8 @@ TEST_UNSAFE(CSP, WesternUS) { // TODO: Fix me! mpe result seems to be right. (See the printing) // It has the same prob as the expected solution. // Is mpe another solution, or the expected solution is unique??? - EXPECT(assert_equal(expected, *mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + EXPECT(assert_equal(expected, mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); // Write out the dual graph for hmetis #ifdef DUAL @@ -143,69 +200,57 @@ TEST_UNSAFE(CSP, WesternUS) { } /* ************************************************************************* */ -TEST_UNSAFE(CSP, AllDiff) { - // Create keys and ordering +TEST(CSP, ArcConsistency) { + // Create keys for Idaho, Arizona, and Utah, allowing three colors for each: size_t nrColors = 3; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); - // Create the CSP + // Create the CSP using just one all-diff constraint, plus constrain Arizona. CSP csp; - vector dkeys; - dkeys += ID, UT, AZ; + vector dkeys{ID, UT, AZ}; csp.addAllDiff(dkeys); csp.addSingleValue(AZ, 2); - // GTSAM_PRINT(csp); - - // Check construction and conversion - SingleValue s(AZ, 2); - DecisionTreeFactor f1(AZ, "0 0 1"); - EXPECT(assert_equal(f1, s.toDecisionTreeFactor())); - - // Check construction and conversion - AllDiff alldiff(dkeys); - DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); - // GTSAM_PRINT(actual); - // actual.dot("actual"); - DecisionTreeFactor f2( - ID & AZ & UT, - "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); - EXPECT(assert_equal(f2, actual)); + // GTSAM_PRINT(csp); // Check an invalid combination, with ID==UT==AZ all same color - DiscreteFactor::Values invalid; + DiscreteValues invalid; invalid[ID.first] = 0; invalid[UT.first] = 1; invalid[AZ.first] = 0; EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // Check a valid combination - DiscreteFactor::Values valid; + DiscreteValues valid; valid[ID.first] = 0; valid[UT.first] = 1; valid[AZ.first] = 2; EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Solve - CSP::sharedValues mpe = csp.optimalAssignment(); - CSP::Values expected; + auto mpe = csp.optimalAssignment(); + DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); - EXPECT(assert_equal(expected, *mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + EXPECT(assert_equal(expected, mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); - // Arc-consistency - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + // ensure arc-consistency, i.e., narrow domains... + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + SingleValue singleValue(AZ, 2); - EXPECT(singleValue.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(0, domains)); - EXPECT(!alldiff.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(2, domains)); - LONGS_EQUAL(2, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(2, domains[2].nrValues()); + AllDiff alldiff(dkeys); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); // Parial application, version 1 - DiscreteFactor::Values known; + DiscreteValues known; known[AZ.first] = 2; DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0"); @@ -222,6 +267,7 @@ TEST_UNSAFE(CSP, AllDiff) { // full arc-consistency test csp.runArcConsistency(nrColors); + // GTSAM_PRINT(csp); } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp index c48d7639d8..6561949b14 100644 --- a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp +++ b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp @@ -126,7 +126,7 @@ class LoopyBelief { // normalize belief double sum = 0.0; for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) { - DiscreteFactor::Values val; + DiscreteValues val; val[key] = v; sum += (*beliefAtKey)(val); } diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index 4eb86fe1fd..7822cbd38b 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -122,7 +122,7 @@ TEST(schedulingExample, test) { // Do exact inference gttic(small); - DiscreteFactor::sharedValues MPE = s.optimalAssignment(); + auto MPE = s.optimalAssignment(); gttoc(small); // print MPE, commented out as unit tests don't print @@ -133,13 +133,13 @@ TEST(schedulingExample, test) { // find the assignment of students to slots with most possible committees // Commented out as not implemented yet - // sharedValues bestSchedule = s.bestSchedule(); - // GTSAM_PRINT(*bestSchedule); + // auto bestSchedule = s.bestSchedule(); + // GTSAM_PRINT(bestSchedule); // find the corresponding most desirable committee assignment // Commented out as not implemented yet - // sharedValues bestAssignment = s.bestAssignment(bestSchedule); - // GTSAM_PRINT(*bestAssignment); + // auto bestAssignment = s.bestAssignment(bestSchedule); + // GTSAM_PRINT(bestAssignment); } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index 4843ae2694..35f3ba8437 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -20,12 +21,12 @@ using namespace gtsam; #define PRINT false +/// A class that encodes Sudoku's as a CSP problem class Sudoku : public CSP { - /// sudoku size - size_t n_; + size_t n_; ///< Side of Sudoku, e.g. 4 or 9 - /// discrete keys - typedef std::pair IJ; + /// Mapping from base i,j coordinates to discrete keys: + using IJ = std::pair; std::map dkeys_; public: @@ -42,15 +43,14 @@ class Sudoku : public CSP { // Create variables, ordering, and unary constraints va_list ap; va_start(ap, n); - Key k = 0; for (size_t i = 0; i < n; ++i) { - for (size_t j = 0; j < n; ++j, ++k) { + for (size_t j = 0; j < n; ++j) { // create the key IJ ij(i, j); - dkeys_[ij] = DiscreteKey(k, n); + Symbol key('1' + i, j + 1); + dkeys_[ij] = DiscreteKey(key, n); // get the unary constraint, if any int value = va_arg(ap, int); - // cout << value << " "; if (value != 0) addSingleValue(dkeys_[ij], value - 1); } // cout << endl; @@ -88,111 +88,171 @@ class Sudoku : public CSP { } /// Print readable form of assignment - void printAssignment(DiscreteFactor::sharedValues assignment) const { + void printAssignment(const DiscreteValues& assignment) const { for (size_t i = 0; i < n_; i++) { for (size_t j = 0; j < n_; j++) { Key k = key(i, j); - cout << 1 + assignment->at(k) << " "; + cout << 1 + assignment.at(k) << " "; } cout << endl; } } /// solve and print solution - void printSolution() { - DiscreteFactor::sharedValues MPE = optimalAssignment(); + void printSolution() const { + auto MPE = optimalAssignment(); printAssignment(MPE); } + + // Print domain + void printDomains(const Domains& domains) { + for (size_t i = 0; i < n_; i++) { + for (size_t j = 0; j < n_; j++) { + Key k = key(i, j); + cout << domains.at(k).base1Str(); + cout << "\t"; + } // i + cout << endl; + } // j + } }; /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, small) { +TEST(Sudoku, small) { Sudoku csp(4, // 1, 0, 0, 4, // 0, 0, 0, 0, // 4, 0, 2, 0, // 0, 1, 0, 0); - // Do BP - csp.runArcConsistency(4, 10, PRINT); - // optimize and check - CSP::sharedValues solution = csp.optimalAssignment(); - CSP::Values expected; + auto solution = csp.optimalAssignment(); + DiscreteValues expected; insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)( csp.key(2, 3), 0)(csp.key(3, 0), 1)(csp.key(3, 1), 0)(csp.key(3, 2), 3)( csp.key(3, 3), 2); - EXPECT(assert_equal(expected, *solution)); + EXPECT(assert_equal(expected, solution)); // csp.printAssignment(solution); + + // Do BP (AC1) + auto domains = csp.runArcConsistency(4, 3); + // csp.printDomains(domains); + Domain domain44 = domains.at(Symbol('4', 4)); + EXPECT_LONGS_EQUAL(1, domain44.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Should only be 16 new Domains + EXPECT_LONGS_EQUAL(16, new_csp.size()); + + // Check that solution + auto new_solution = new_csp.optimalAssignment(); + // csp.printAssignment(new_solution); + EXPECT(assert_equal(expected, new_solution)); } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, easy) { - Sudoku sudoku(9, // - 0, 0, 5, 0, 9, 0, 0, 0, 1, // - 0, 0, 0, 0, 0, 2, 0, 7, 3, // - 7, 6, 0, 0, 0, 8, 2, 0, 0, // - - 0, 1, 2, 0, 0, 9, 0, 0, 4, // - 0, 0, 0, 2, 0, 3, 0, 0, 0, // - 3, 0, 0, 1, 0, 0, 9, 6, 0, // - - 0, 0, 1, 9, 0, 0, 0, 5, 8, // - 9, 7, 0, 5, 0, 0, 0, 0, 0, // - 5, 0, 0, 0, 3, 0, 7, 0, 0); - - // Do BP - sudoku.runArcConsistency(4, 10, PRINT); - - // sudoku.printSolution(); // don't do it +TEST(Sudoku, easy) { + Sudoku csp(9, // + 0, 0, 5, 0, 9, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 2, 0, 7, 3, // + 7, 6, 0, 0, 0, 8, 2, 0, 0, // + + 0, 1, 2, 0, 0, 9, 0, 0, 4, // + 0, 0, 0, 2, 0, 3, 0, 0, 0, // + 3, 0, 0, 1, 0, 0, 9, 6, 0, // + + 0, 0, 1, 9, 0, 0, 0, 5, 8, // + 9, 7, 0, 5, 0, 0, 0, 0, 0, // + 5, 0, 0, 0, 3, 0, 7, 0, 0); + + // csp.printSolution(); // don't do it + + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 26 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 26, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, extreme) { - Sudoku sudoku(9, // - 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // - 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // - 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // - 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // - 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // - 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // - 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // - 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); +TEST(Sudoku, extreme) { + Sudoku csp(9, // + 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // + 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // + 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // + 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // + 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // + 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + csp.runArcConsistency(9, 10); #ifdef METIS - VariableIndexOrdered index(sudoku); + VariableIndexOrdered index(csp); index.print("index"); ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt"); index.outputMetisFormat(os); #endif - // sudoku.printSolution(); // don't do it -} - -/* ************************************************************************* */ -TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) { - Sudoku sudoku(9, // - 9, 5, 0, 0, 0, 6, 0, 0, 0, // - 0, 8, 4, 0, 7, 0, 0, 0, 0, // - 6, 2, 0, 5, 0, 0, 4, 0, 0, // + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(2, domain99.nrValues()); - 0, 0, 0, 2, 9, 0, 6, 0, 0, // - 0, 9, 0, 0, 0, 0, 0, 2, 0, // - 0, 0, 2, 0, 6, 3, 0, 0, 0, // + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 20 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 20, new_csp.size()); - 0, 0, 9, 0, 0, 7, 0, 6, 8, // - 0, 0, 0, 0, 3, 0, 2, 9, 0, // - 0, 0, 0, 1, 0, 0, 0, 3, 7); - - // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + // csp.printSolution(); // still don't do it ! :-( +} - // sudoku.printSolution(); // don't do it +/* ************************************************************************* */ +TEST(Sudoku, AJC_3star_Feb8_2012) { + Sudoku csp(9, // + 9, 5, 0, 0, 0, 6, 0, 0, 0, // + 0, 8, 4, 0, 7, 0, 0, 0, 0, // + 6, 2, 0, 5, 0, 0, 4, 0, 0, // + + 0, 0, 0, 2, 9, 0, 6, 0, 0, // + 0, 9, 0, 0, 0, 0, 0, 2, 0, // + 0, 0, 2, 0, 6, 3, 0, 0, 0, // + + 0, 0, 9, 0, 0, 7, 0, 6, 8, // + 0, 0, 0, 0, 3, 0, 2, 9, 0, // + 0, 0, 0, 1, 0, 0, 0, 3, 7); + + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Just the 81 new Domains + EXPECT_LONGS_EQUAL(81, new_csp.size()); + + // Check that solution + auto solution = new_csp.optimalAssignment(); + // csp.printAssignment(solution); + EXPECT_LONGS_EQUAL(6, solution.at(key99)); } /* ************************************************************************* */ diff --git a/gtsam_unstable/slam/LocalOrientedPlane3Factor.h b/gtsam_unstable/slam/LocalOrientedPlane3Factor.h index 5264c8f4ba..7f71b282b0 100644 --- a/gtsam_unstable/slam/LocalOrientedPlane3Factor.h +++ b/gtsam_unstable/slam/LocalOrientedPlane3Factor.h @@ -9,6 +9,8 @@ #include #include +#include + #include namespace gtsam { @@ -32,9 +34,9 @@ namespace gtsam { * a local linearisation point for the plane. The plane is representated and * optimized in x1 frame in the optimization. */ -class LocalOrientedPlane3Factor: public NoiseModelFactor3 { -protected: +class GTSAM_UNSTABLE_EXPORT LocalOrientedPlane3Factor + : public NoiseModelFactor3 { + protected: OrientedPlane3 measured_p_; typedef NoiseModelFactor3 Base; public: diff --git a/gtsam_unstable/slam/PoseToPointFactor.h b/gtsam_unstable/slam/PoseToPointFactor.h index ec7da22ef1..cab48e5069 100644 --- a/gtsam_unstable/slam/PoseToPointFactor.h +++ b/gtsam_unstable/slam/PoseToPointFactor.h @@ -1,11 +1,14 @@ /** * @file PoseToPointFactor.hpp - * @brief This factor can be used to track a 3D landmark over time by - *providing local measurements of its location. + * @brief This factor can be used to model relative position measurements + * from a (2D or 3D) pose to a landmark * @author David Wisth + * @author Luca Carlone **/ #pragma once +#include +#include #include #include #include @@ -17,12 +20,13 @@ namespace gtsam { * A class for a measurement between a pose and a point. * @addtogroup SLAM */ -class PoseToPointFactor : public NoiseModelFactor2 { +template +class PoseToPointFactor : public NoiseModelFactor2 { private: typedef PoseToPointFactor This; - typedef NoiseModelFactor2 Base; + typedef NoiseModelFactor2 Base; - Point3 measured_; /** the point measurement in local coordinates */ + POINT measured_; /** the point measurement in local coordinates */ public: // shorthand for a smart pointer to a factor @@ -32,7 +36,7 @@ class PoseToPointFactor : public NoiseModelFactor2 { PoseToPointFactor() {} /** Constructor */ - PoseToPointFactor(Key key1, Key key2, const Point3& measured, + PoseToPointFactor(Key key1, Key key2, const POINT& measured, const SharedNoiseModel& model) : Base(model, key1, key2), measured_(measured) {} @@ -41,8 +45,8 @@ class PoseToPointFactor : public NoiseModelFactor2 { /** implement functions needed for Testable */ /** print */ - virtual void print(const std::string& s, const KeyFormatter& keyFormatter = - DefaultKeyFormatter) const { + void print(const std::string& s, const KeyFormatter& keyFormatter = + DefaultKeyFormatter) const override { std::cout << s << "PoseToPointFactor(" << keyFormatter(this->key1()) << "," << keyFormatter(this->key2()) << ")\n" << " measured: " << measured_.transpose() << std::endl; @@ -50,30 +54,31 @@ class PoseToPointFactor : public NoiseModelFactor2 { } /** equals */ - virtual bool equals(const NonlinearFactor& expected, - double tol = 1e-9) const { + bool equals(const NonlinearFactor& expected, + double tol = 1e-9) const override { const This* e = dynamic_cast(&expected); return e != nullptr && Base::equals(*e, tol) && - traits::Equals(this->measured_, e->measured_, tol); + traits::Equals(this->measured_, e->measured_, tol); } /** implement functions needed to derive from Factor */ /** vector of errors - * @brief Error = wTwi.inverse()*wPwp - measured_ - * @param wTwi The pose of the sensor in world coordinates - * @param wPwp The estimated point location in world coordinates + * @brief Error = w_T_b.inverse()*w_P - measured_ + * @param w_T_b The pose of the body in world coordinates + * @param w_P The estimated point location in world coordinates * * Note: measured_ and the error are in local coordiantes. */ - Vector evaluateError(const Pose3& wTwi, const Point3& wPwp, - boost::optional H1 = boost::none, - boost::optional H2 = boost::none) const { - return wTwi.transformTo(wPwp, H1, H2) - measured_; + Vector evaluateError( + const POSE& w_T_b, const POINT& w_P, + boost::optional H1 = boost::none, + boost::optional H2 = boost::none) const override { + return w_T_b.transformTo(w_P, H1, H2) - measured_; } /** return the measured */ - const Point3& measured() const { return measured_; } + const POINT& measured() const { return measured_; } private: /** Serialization function */ diff --git a/gtsam_unstable/slam/ProjectionFactorPPPC.h b/gtsam_unstable/slam/ProjectionFactorPPPC.h index fbc11503c5..53860efdc1 100644 --- a/gtsam_unstable/slam/ProjectionFactorPPPC.h +++ b/gtsam_unstable/slam/ProjectionFactorPPPC.h @@ -18,9 +18,11 @@ #pragma once -#include -#include #include +#include +#include +#include + #include namespace gtsam { @@ -30,28 +32,27 @@ namespace gtsam { * estimates the body pose, body-camera transform, 3D landmark, and calibration. * @addtogroup SLAM */ - template - class ProjectionFactorPPPC: public NoiseModelFactor4 { - protected: - - Point2 measured_; ///< 2D measurement +template +class GTSAM_UNSTABLE_EXPORT ProjectionFactorPPPC + : public NoiseModelFactor4 { + protected: + Point2 measured_; ///< 2D measurement - // verbosity handling for Cheirality Exceptions - bool throwCheirality_; ///< If true, rethrows Cheirality exceptions (default: false) - bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) + // verbosity handling for Cheirality Exceptions + bool throwCheirality_; ///< If true, rethrows Cheirality exceptions (default: false) + bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) - public: + public: + /// shorthand for base class type + typedef NoiseModelFactor4 Base; - /// shorthand for base class type - typedef NoiseModelFactor4 Base; + /// shorthand for this class + typedef ProjectionFactorPPPC This; - /// shorthand for this class - typedef ProjectionFactorPPPC This; + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; - /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; - - /// Default constructor + /// Default constructor ProjectionFactorPPPC() : measured_(0.0, 0.0), throwCheirality_(false), verboseCheirality_(false) { } @@ -168,7 +169,7 @@ namespace gtsam { ar & BOOST_SERIALIZATION_NVP(throwCheirality_); ar & BOOST_SERIALIZATION_NVP(verboseCheirality_); } - }; +}; /// traits template diff --git a/gtsam_unstable/slam/ProjectionFactorRollingShutter.h b/gtsam_unstable/slam/ProjectionFactorRollingShutter.h index c92653c13e..2aeaa48249 100644 --- a/gtsam_unstable/slam/ProjectionFactorRollingShutter.h +++ b/gtsam_unstable/slam/ProjectionFactorRollingShutter.h @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -40,7 +41,7 @@ namespace gtsam { * @addtogroup SLAM */ -class ProjectionFactorRollingShutter +class GTSAM_UNSTABLE_EXPORT ProjectionFactorRollingShutter : public NoiseModelFactor3 { protected: // Keep a copy of measurement and calibration for I/O diff --git a/gtsam_unstable/slam/ReadMe.md b/gtsam_unstable/slam/README.md similarity index 100% rename from gtsam_unstable/slam/ReadMe.md rename to gtsam_unstable/slam/README.md diff --git a/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h b/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h index 23203be67b..ff84fcd16a 100644 --- a/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h +++ b/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h @@ -20,6 +20,7 @@ #include #include +#include namespace gtsam { /** @@ -41,12 +42,14 @@ namespace gtsam { * @addtogroup SLAM */ template -class SmartProjectionPoseFactorRollingShutter +class GTSAM_UNSTABLE_EXPORT SmartProjectionPoseFactorRollingShutter : public SmartProjectionFactor { private: typedef SmartProjectionFactor Base; typedef SmartProjectionPoseFactorRollingShutter This; typedef typename CAMERA::CalibrationType CALIBRATION; + typedef typename CAMERA::Measurement MEASUREMENT; + typedef typename CAMERA::MeasurementVector MEASUREMENTS; protected: /// The keys of the pose of the body (with respect to an external world @@ -68,12 +71,6 @@ class SmartProjectionPoseFactorRollingShutter public: EIGEN_MAKE_ALIGNED_OPERATOR_NEW - typedef CAMERA Camera; - typedef CameraSet Cameras; - - /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; - static const int DimBlock = 12; ///< size of the variable stacking 2 poses from which the observation ///< pose is interpolated @@ -84,6 +81,12 @@ class SmartProjectionPoseFactorRollingShutter typedef std::vector> FBlocks; // vector of F blocks + typedef CAMERA Camera; + typedef CameraSet Cameras; + + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; + /// Default constructor, only for serialization SmartProjectionPoseFactorRollingShutter() {} @@ -125,7 +128,7 @@ class SmartProjectionPoseFactorRollingShutter * interpolated pose is the same as world_P_body_key1 * @param cameraId ID of the camera taking the measurement (default 0) */ - void add(const Point2& measured, const Key& world_P_body_key1, + void add(const MEASUREMENT& measured, const Key& world_P_body_key1, const Key& world_P_body_key2, const double& alpha, const size_t& cameraId = 0) { // store measurements in base class @@ -164,7 +167,7 @@ class SmartProjectionPoseFactorRollingShutter * @param cameraIds IDs of the cameras taking each measurement (same order as * the measurements) */ - void add(const Point2Vector& measurements, + void add(const MEASUREMENTS& measurements, const std::vector>& world_P_body_key_pairs, const std::vector& alphas, const FastVector& cameraIds = FastVector()) { @@ -330,12 +333,13 @@ class SmartProjectionPoseFactorRollingShutter const typename Base::Camera& camera_i = (*cameraRig_)[cameraIds_[i]]; auto body_P_cam = camera_i.pose(); auto w_P_cam = w_P_body.compose(body_P_cam, dPoseCam_dInterpPose); - PinholeCamera camera(w_P_cam, camera_i.calibration()); + typename Base::Camera camera( + w_P_cam, make_shared( + camera_i.calibration())); // get jacobians and error vector for current measurement - Point2 reprojectionError_i = - Point2(camera.project(*this->result_, dProject_dPoseCam, Ei) - - this->measured_.at(i)); + Point2 reprojectionError_i = camera.reprojectionError( + *this->result_, this->measured_.at(i), dProject_dPoseCam, Ei); Eigen::Matrix J; // 2 x 12 J.block(0, 0, ZDim, 6) = dProject_dPoseCam * dPoseCam_dInterpPose * @@ -403,7 +407,7 @@ class SmartProjectionPoseFactorRollingShutter for (size_t i = 0; i < Fs.size(); i++) Fs[i] = this->noiseModel_->Whiten(Fs[i]); - Matrix3 P = Base::Cameras::PointCov(E, lambda, diagonalDamping); + Matrix3 P = Cameras::PointCov(E, lambda, diagonalDamping); // Collect all the key pairs: these are the keys that correspond to the // blocks in Fs (on which we apply the Schur Complement) diff --git a/gtsam_unstable/slam/SmartStereoProjectionFactor.h b/gtsam_unstable/slam/SmartStereoProjectionFactor.h index 88e1129981..5cdfb2ab7c 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionFactor.h +++ b/gtsam_unstable/slam/SmartStereoProjectionFactor.h @@ -20,18 +20,18 @@ #pragma once -#include -#include - -#include #include #include -#include +#include #include +#include +#include +#include #include +#include -#include #include +#include #include namespace gtsam { @@ -49,8 +49,9 @@ typedef SmartProjectionParams SmartStereoProjectionParams; * If you'd like to store poses in values instead of cameras, use * SmartStereoProjectionPoseFactor instead */ -class SmartStereoProjectionFactor: public SmartFactorBase { -private: +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionFactor + : public SmartFactorBase { + private: typedef SmartFactorBase Base; diff --git a/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h b/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h index ce6df15cb7..e20241a0ee 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h +++ b/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h @@ -40,7 +40,8 @@ namespace gtsam { * are Pose3 variables). * @addtogroup SLAM */ -class SmartStereoProjectionFactorPP : public SmartStereoProjectionFactor { +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionFactorPP + : public SmartStereoProjectionFactor { protected: /// shared pointer to calibration object (one for each camera) std::vector> K_all_; @@ -294,7 +295,6 @@ class SmartStereoProjectionFactorPP : public SmartStereoProjectionFactor { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar & BOOST_SERIALIZATION_NVP(K_all_); } - }; // end of class declaration diff --git a/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h b/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h index 2a8180ac51..a46000a686 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h +++ b/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h @@ -43,7 +43,8 @@ namespace gtsam { * This factor requires that values contains the involved poses (Pose3). * @addtogroup SLAM */ -class SmartStereoProjectionPoseFactor : public SmartStereoProjectionFactor { +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionPoseFactor + : public SmartStereoProjectionFactor { protected: /// shared pointer to calibration object (one for each camera) std::vector> K_all_; diff --git a/gtsam_unstable/slam/tests/testPoseToPointFactor.cpp b/gtsam_unstable/slam/tests/testPoseToPointFactor.cpp new file mode 100644 index 0000000000..5aaaaec531 --- /dev/null +++ b/gtsam_unstable/slam/tests/testPoseToPointFactor.cpp @@ -0,0 +1,161 @@ +/** + * @file testPoseToPointFactor.cpp + * @brief + * @author David Wisth + * @author Luca Carlone + * @date June 20, 2020 + */ + +#include +#include +#include + +using namespace gtsam; +using namespace gtsam::noiseModel; + +/* ************************************************************************* */ +// Verify zero error when there is no noise +TEST(PoseToPointFactor, errorNoiseless_2D) { + Pose2 pose = Pose2::identity(); + Point2 point(1.0, 2.0); + Point2 noise(0.0, 0.0); + Point2 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(2, 0.05)); + Vector expectedError = Vector2(0.0, 0.0); + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Verify expected error in test scenario +TEST(PoseToPointFactor, errorNoise_2D) { + Pose2 pose = Pose2::identity(); + Point2 point(1.0, 2.0); + Point2 noise(-1.0, 0.5); + Point2 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(2, 0.05)); + Vector expectedError = -noise; + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Check Jacobians are correct +TEST(PoseToPointFactor, jacobian_2D) { + // Measurement + gtsam::Point2 l_meas(1, 2); + + // Linearisation point + gtsam::Point2 p_t(-5, 12); + gtsam::Rot2 p_R(1.5 * M_PI); + Pose2 p(p_R, p_t); + + gtsam::Point2 l(3, 0); + + // Factor + Key pose_key(1); + Key point_key(2); + SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector2(0.1, 0.1)); + PoseToPointFactor factor(pose_key, point_key, l_meas, noise); + + // Calculate numerical derivatives + auto f = std::bind(&PoseToPointFactor::evaluateError, factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none); + Matrix numerical_H1 = numericalDerivative21(f, p, l); + Matrix numerical_H2 = numericalDerivative22(f, p, l); + + // Use the factor to calculate the derivative + Matrix actual_H1; + Matrix actual_H2; + factor.evaluateError(p, l, actual_H1, actual_H2); + + // Verify we get the expected error + EXPECT(assert_equal(numerical_H1, actual_H1, 1e-8)); + EXPECT(assert_equal(numerical_H2, actual_H2, 1e-8)); +} + +/* ************************************************************************* */ +// Verify zero error when there is no noise +TEST(PoseToPointFactor, errorNoiseless_3D) { + Pose3 pose = Pose3::identity(); + Point3 point(1.0, 2.0, 3.0); + Point3 noise(0.0, 0.0, 0.0); + Point3 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(3, 0.05)); + Vector expectedError = Vector3(0.0, 0.0, 0.0); + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Verify expected error in test scenario +TEST(PoseToPointFactor, errorNoise_3D) { + Pose3 pose = Pose3::identity(); + Point3 point(1.0, 2.0, 3.0); + Point3 noise(-1.0, 0.5, 0.3); + Point3 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(3, 0.05)); + Vector expectedError = -noise; + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Check Jacobians are correct +TEST(PoseToPointFactor, jacobian_3D) { + // Measurement + gtsam::Point3 l_meas = gtsam::Point3(1, 2, 3); + + // Linearisation point + gtsam::Point3 p_t = gtsam::Point3(-5, 12, 2); + gtsam::Rot3 p_R = gtsam::Rot3::RzRyRx(1.5 * M_PI, -0.3 * M_PI, 0.4 * M_PI); + Pose3 p(p_R, p_t); + + gtsam::Point3 l = gtsam::Point3(3, 0, 5); + + // Factor + Key pose_key(1); + Key point_key(2); + SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector3(0.1, 0.1, 0.1)); + PoseToPointFactor factor(pose_key, point_key, l_meas, noise); + + // Calculate numerical derivatives + auto f = std::bind(&PoseToPointFactor::evaluateError, factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none); + Matrix numerical_H1 = numericalDerivative21(f, p, l); + Matrix numerical_H2 = numericalDerivative22(f, p, l); + + // Use the factor to calculate the derivative + Matrix actual_H1; + Matrix actual_H2; + factor.evaluateError(p, l, actual_H1, actual_H2); + + // Verify we get the expected error + EXPECT(assert_equal(numerical_H1, actual_H1, 1e-8)); + EXPECT(assert_equal(numerical_H2, actual_H2, 1e-8)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam_unstable/slam/tests/testPoseToPointFactor.h b/gtsam_unstable/slam/tests/testPoseToPointFactor.h deleted file mode 100644 index e0e5c45817..0000000000 --- a/gtsam_unstable/slam/tests/testPoseToPointFactor.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * @file testPoseToPointFactor.cpp - * @brief - * @author David Wisth - * @date June 20, 2020 - */ - -#include -#include -#include - -using namespace gtsam; -using namespace gtsam::noiseModel; - -/// Verify zero error when there is no noise -TEST(PoseToPointFactor, errorNoiseless) { - Pose3 pose = Pose3::identity(); - Point3 point(1.0, 2.0, 3.0); - Point3 noise(0.0, 0.0, 0.0); - Point3 measured = t + noise; - - Key pose_key(1); - Key point_key(2); - PoseToPointFactor factor(pose_key, point_key, measured, - Isotropic::Sigma(3, 0.05)); - Vector expectedError = Vector3(0.0, 0.0, 0.0); - Vector actualError = factor.evaluateError(pose, point); - EXPECT(assert_equal(expectedError, actualError, 1E-5)); -} - -/// Verify expected error in test scenario -TEST(PoseToPointFactor, errorNoise) { - Pose3 pose = Pose3::identity(); - Point3 point(1.0, 2.0, 3.0); - Point3 noise(-1.0, 0.5, 0.3); - Point3 measured = t + noise; - - Key pose_key(1); - Key point_key(2); - PoseToPointFactor factor(pose_key, point_key, measured, - Isotropic::Sigma(3, 0.05)); - Vector expectedError = noise; - Vector actualError = factor.evaluateError(pose, point); - EXPECT(assert_equal(expectedError, actualError, 1E-5)); -} - -/// Check Jacobians are correct -TEST(PoseToPointFactor, jacobian) { - // Measurement - gtsam::Point3 l_meas = gtsam::Point3(1, 2, 3); - - // Linearisation point - gtsam::Point3 p_t = gtsam::Point3(-5, 12, 2); - gtsam::Rot3 p_R = gtsam::Rot3::RzRyRx(1.5 * M_PI, -0.3 * M_PI, 0.4 * M_PI); - Pose3 p(p_R, p_t); - - gtsam::Point3 l = gtsam::Point3(3, 0, 5); - - // Factor - Key pose_key(1); - Key point_key(2); - SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector3(0.1, 0.1, 0.1)); - PoseToPointFactor factor(pose_key, point_key, l_meas, noise); - - // Calculate numerical derivatives - auto f = std::bind(&PoseToPointFactor::evaluateError, factor, _1, _2, - boost::none, boost::none); - Matrix numerical_H1 = numericalDerivative21(f, p, l); - Matrix numerical_H2 = numericalDerivative22(f, p, l); - - // Use the factor to calculate the derivative - Matrix actual_H1; - Matrix actual_H2; - factor.evaluateError(p, l, actual_H1, actual_H2); - - // Verify we get the expected error - EXPECT_TRUE(assert_equal(numerical_H1, actual_H1, 1e-8)); - EXPECT_TRUE(assert_equal(numerical_H2, actual_H2, 1e-8)); -} - -/* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} -/* ************************************************************************* */ diff --git a/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp b/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp index c17ad7e1ce..b5962d777b 100644 --- a/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp +++ b/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp @@ -1317,10 +1317,10 @@ TEST(SmartProjectionPoseFactorRollingShutter, #ifndef DISABLE_TIMING #include //-Total: 0 CPU (0 times, 0 wall, 0.21 children, min: 0 max: 0) -//| -SF RS LINEARIZE: 0.09 CPU -// (10000 times, 0.124106 wall, 0.09 children, min: 0 max: 0) -//| -RS LINEARIZE: 0.09 CPU -// (10000 times, 0.068719 wall, 0.09 children, min: 0 max: 0) +//| -SF RS LINEARIZE: 0.14 CPU +//(10000 times, 0.131202 wall, 0.14 children, min: 0 max: 0) +//| -RS LINEARIZE: 0.06 CPU +//(10000 times, 0.066951 wall, 0.06 children, min: 0 max: 0) /* *************************************************************************/ TEST(SmartProjectionPoseFactorRollingShutter, timing) { using namespace vanillaPose; @@ -1384,6 +1384,105 @@ TEST(SmartProjectionPoseFactorRollingShutter, timing) { } #endif +#include +/* ************************************************************************* */ +// spherical Camera with rolling shutter effect +namespace sphericalCameraRS { +typedef SphericalCamera Camera; +typedef CameraSet Cameras; +typedef SmartProjectionPoseFactorRollingShutter SmartFactorRS_spherical; +Pose3 interp_pose1 = interpolate(level_pose, pose_right, interp_factor1); +Pose3 interp_pose2 = interpolate(pose_right, pose_above, interp_factor2); +Pose3 interp_pose3 = interpolate(pose_above, level_pose, interp_factor3); +static EmptyCal::shared_ptr emptyK(new EmptyCal()); +Camera cam1(interp_pose1, emptyK); +Camera cam2(interp_pose2, emptyK); +Camera cam3(interp_pose3, emptyK); +} // namespace sphericalCameraRS + +/* *************************************************************************/ +TEST(SmartProjectionPoseFactorRollingShutter, + optimization_3poses_sphericalCameras) { + using namespace sphericalCameraRS; + std::vector measurements_lmk1, measurements_lmk2, measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, + measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, + measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, + measurements_lmk3); + + // create inputs + std::vector> key_pairs; + key_pairs.push_back(std::make_pair(x1, x2)); + key_pairs.push_back(std::make_pair(x2, x3)); + key_pairs.push_back(std::make_pair(x3, x1)); + + std::vector interp_factors; + interp_factors.push_back(interp_factor1); + interp_factors.push_back(interp_factor2); + interp_factors.push_back(interp_factor3); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with RS factors + params.setRankTolerance(0.1); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), emptyK)); + + SmartFactorRS_spherical::shared_ptr smartFactor1( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors); + + SmartFactorRS_spherical::shared_ptr smartFactor2( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors); + + SmartFactorRS_spherical::shared_ptr smartFactor3( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT( // check that the pose is actually noisy + assert_equal(Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, + -0.0313952598, -0.000986635786, 0.0314107591, + -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-6)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/matlab/+gtsam/VisualISAMInitialize.m b/matlab/+gtsam/VisualISAMInitialize.m index 29f8b3b46f..9b834e3e13 100644 --- a/matlab/+gtsam/VisualISAMInitialize.m +++ b/matlab/+gtsam/VisualISAMInitialize.m @@ -12,11 +12,11 @@ isam = ISAM2(params); %% Set Noise parameters -noiseModels.pose = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]'); +noiseModels.pose = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]', true); %noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]'); -noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.05 0.05 0.05 0.2 0.2 0.2]'); -noiseModels.point = noiseModel.Isotropic.Sigma(3, 0.1); -noiseModels.measurement = noiseModel.Isotropic.Sigma(2, 1.0); +noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.05 0.05 0.05 0.2 0.2 0.2]', true); +noiseModels.point = noiseModel.Isotropic.Sigma(3, 0.1, true); +noiseModels.measurement = noiseModel.Isotropic.Sigma(2, 1.0, true); %% Add constraints/priors % TODO: should not be from ground truth! diff --git a/matlab/CMakeLists.txt b/matlab/CMakeLists.txt index 28e7cce6e4..749ad870ac 100644 --- a/matlab/CMakeLists.txt +++ b/matlab/CMakeLists.txt @@ -64,8 +64,21 @@ set(ignore gtsam::Point3 gtsam::CustomFactor) +set(interface_files + ${GTSAM_SOURCE_DIR}/gtsam/gtsam.i + ${GTSAM_SOURCE_DIR}/gtsam/base/base.i + ${GTSAM_SOURCE_DIR}/gtsam/basis/basis.i + ${GTSAM_SOURCE_DIR}/gtsam/geometry/geometry.i + ${GTSAM_SOURCE_DIR}/gtsam/linear/linear.i + ${GTSAM_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i + ${GTSAM_SOURCE_DIR}/gtsam/symbolic/symbolic.i + ${GTSAM_SOURCE_DIR}/gtsam/sam/sam.i + ${GTSAM_SOURCE_DIR}/gtsam/slam/slam.i + ${GTSAM_SOURCE_DIR}/gtsam/sfm/sfm.i + ${GTSAM_SOURCE_DIR}/gtsam/navigation/navigation.i +) # Wrap -matlab_wrap(${GTSAM_SOURCE_DIR}/gtsam/gtsam.i "${GTSAM_ADDITIONAL_LIBRARIES}" +matlab_wrap("${interface_files}" "gtsam" "${GTSAM_ADDITIONAL_LIBRARIES}" "" "${mexFlags}" "${ignore}") # Wrap version for gtsam_unstable diff --git a/matlab/gtsam_tests/testUtilities.m b/matlab/gtsam_tests/testUtilities.m index da8dec7894..2bfe81a833 100644 --- a/matlab/gtsam_tests/testUtilities.m +++ b/matlab/gtsam_tests/testUtilities.m @@ -45,3 +45,12 @@ CHECK('size==3', actual.size==3); CHECK('actual.count(x1)', actual.count(x1)); +% test extractVectors +values = Values(); +values.insert(symbol('x', 0), (1:6)'); +values.insert(symbol('x', 1), (7:12)'); +values.insert(symbol('x', 2), (13:18)'); +values.insert(symbol('x', 7), Pose3()); +actual = utilities.extractVectors(values, 'x'); +expected = reshape(1:18, 6, 3)'; +CHECK('extractVectors', all(actual == expected, 'all')); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index e2444a51af..d3b20e3126 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -49,11 +49,13 @@ set(ignore gtsam::Pose3Vector gtsam::KeyVector gtsam::BinaryMeasurementsUnit3 + gtsam::DiscreteKey gtsam::KeyPairDoubleMap) set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/base/base.i + ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i ${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i ${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i ${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i diff --git a/python/gtsam/notebooks/DiscreteBayesTree.ipynb b/python/gtsam/notebooks/DiscreteBayesTree.ipynb new file mode 100644 index 0000000000..066c31d6a8 --- /dev/null +++ b/python/gtsam/notebooks/DiscreteBayesTree.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Discrete Bayes Tree\n", + "\n", + "An example of building a Bayes net, then eliminating it into a Bayes tree. Mirrors the code in `testDiscreteBayesTree.cpp` .\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesTree, DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " #TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n8\n\n8\n\n\n\n0\n\n0\n\n\n\n8->0\n\n\n\n\n\n1\n\n1\n\n\n\n8->1\n\n\n\n\n\n12\n\n12\n\n\n\n12->8\n\n\n\n\n\n12->0\n\n\n\n\n\n12->1\n\n\n\n\n\n9\n\n9\n\n\n\n12->9\n\n\n\n\n\n2\n\n2\n\n\n\n12->2\n\n\n\n\n\n3\n\n3\n\n\n\n12->3\n\n\n\n\n\n9->2\n\n\n\n\n\n9->3\n\n\n\n\n\n10\n\n10\n\n\n\n4\n\n4\n\n\n\n10->4\n\n\n\n\n\n5\n\n5\n\n\n\n10->5\n\n\n\n\n\n13\n\n13\n\n\n\n13->10\n\n\n\n\n\n13->4\n\n\n\n\n\n13->5\n\n\n\n\n\n11\n\n11\n\n\n\n13->11\n\n\n\n\n\n6\n\n6\n\n\n\n13->6\n\n\n\n\n\n7\n\n7\n\n\n\n13->7\n\n\n\n\n\n11->6\n\n\n\n\n\n11->7\n\n\n\n\n\n14\n\n14\n\n\n\n14->8\n\n\n\n\n\n14->12\n\n\n\n\n\n14->9\n\n\n\n\n\n14->10\n\n\n\n\n\n14->13\n\n\n\n\n\n14->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c615b0>" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define DiscreteKey pairs.\n", + "keys = [(j, 2) for j in range(15)]\n", + "\n", + "# Create thin-tree Bayesnet.\n", + "bayesNet = DiscreteBayesNet()\n", + "\n", + "\n", + "bayesNet.add(keys[0], P(keys[8], keys[12]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[1], P(keys[8], keys[12]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[2], P(keys[9], keys[12]), \"1/4 8/2 2/3 4/1\")\n", + "bayesNet.add(keys[3], P(keys[9], keys[12]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[4], P(keys[10], keys[13]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[5], P(keys[10], keys[13]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[6], P(keys[11], keys[13]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[7], P(keys[11], keys[13]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[8], P(keys[12], keys[14]), \"T 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[9], P(keys[12], keys[14]), \"4/1 2/3 F 1/4\")\n", + "bayesNet.add(keys[10], P(keys[13], keys[14]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[11], P(keys[13], keys[14]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[12], P(keys[14]), \"3/1 3/1\")\n", + "bayesNet.add(keys[13], P(keys[14]), \"1/3 3/1\")\n", + "\n", + "bayesNet.add(keys[14], P(), \"1/3\")\n", + "\n", + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 1, 4: 1, 5: 1, 6: 0, 7: 1, 8: 0, 9: 0, 10: 0, 11: 0, 12: 1, 13: 1, 14: 0}\n", + "DiscreteValues{0: 0, 1: 1, 2: 0, 3: 0, 4: 1, 5: 0, 6: 0, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n", + "DiscreteValues{0: 1, 1: 0, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1, 7: 0, 8: 1, 9: 0, 10: 1, 11: 1, 12: 0, 13: 1, 14: 0}\n", + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1, 8: 0, 9: 1, 10: 0, 11: 0, 12: 1, 13: 0, 14: 1}\n", + "DiscreteValues{0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 1, 6: 1, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n" + ] + } + ], + "source": [ + "# Sample Bayes net (needs conditionals added in elimination order!)\n", + "for i in range(5):\n", + " print(bayesNet.sample())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\n\nvar0\n\n0\n\n\n\nfactor0\n\n\n\n\nvar0--factor0\n\n\n\n\nvar1\n\n1\n\n\n\nfactor1\n\n\n\n\nvar1--factor1\n\n\n\n\nvar2\n\n2\n\n\n\nfactor2\n\n\n\n\nvar2--factor2\n\n\n\n\nvar3\n\n3\n\n\n\nfactor3\n\n\n\n\nvar3--factor3\n\n\n\n\nvar4\n\n4\n\n\n\nfactor4\n\n\n\n\nvar4--factor4\n\n\n\n\nvar5\n\n5\n\n\n\nfactor5\n\n\n\n\nvar5--factor5\n\n\n\n\nvar6\n\n6\n\n\n\nfactor6\n\n\n\n\nvar6--factor6\n\n\n\n\nvar7\n\n7\n\n\n\nfactor7\n\n\n\n\nvar7--factor7\n\n\n\n\nvar8\n\n8\n\n\n\nvar8--factor0\n\n\n\n\nvar8--factor1\n\n\n\n\nfactor8\n\n\n\n\nvar8--factor8\n\n\n\n\nvar9\n\n9\n\n\n\nvar9--factor2\n\n\n\n\nvar9--factor3\n\n\n\n\nfactor9\n\n\n\n\nvar9--factor9\n\n\n\n\nvar10\n\n10\n\n\n\nvar10--factor4\n\n\n\n\nvar10--factor5\n\n\n\n\nfactor10\n\n\n\n\nvar10--factor10\n\n\n\n\nvar11\n\n11\n\n\n\nvar11--factor6\n\n\n\n\nvar11--factor7\n\n\n\n\nfactor11\n\n\n\n\nvar11--factor11\n\n\n\n\nvar12\n\n12\n\n\n\nvar14\n\n14\n\n\n\nvar12--var14\n\n\n\n\nvar12--factor0\n\n\n\n\nvar12--factor1\n\n\n\n\nvar12--factor2\n\n\n\n\nvar12--factor3\n\n\n\n\nvar12--factor8\n\n\n\n\nvar12--factor9\n\n\n\n\nvar13\n\n13\n\n\n\nvar13--var14\n\n\n\n\nvar13--factor4\n\n\n\n\nvar13--factor5\n\n\n\n\nvar13--factor6\n\n\n\n\nvar13--factor7\n\n\n\n\nvar13--factor10\n\n\n\n\nvar13--factor11\n\n\n\n\nvar14--factor8\n\n\n\n\nvar14--factor9\n\n\n\n\nvar14--factor10\n\n\n\n\nvar14--factor11\n\n\n\n\nfactor14\n\n\n\n\nvar14--factor14\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61f10>" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n0\n\n8,12,14\n\n\n\n1\n\n0 : 8,12\n\n\n\n0->1\n\n\n\n\n\n2\n\n1 : 8,12\n\n\n\n0->2\n\n\n\n\n\n3\n\n9 : 12,14\n\n\n\n0->3\n\n\n\n\n\n6\n\n10,13 : 14\n\n\n\n0->6\n\n\n\n\n\n4\n\n2 : 9,12\n\n\n\n3->4\n\n\n\n\n\n5\n\n3 : 9,12\n\n\n\n3->5\n\n\n\n\n\n7\n\n4 : 10,13\n\n\n\n6->7\n\n\n\n\n\n8\n\n5 : 10,13\n\n\n\n6->8\n\n\n\n\n\n9\n\n11 : 13,14\n\n\n\n6->9\n\n\n\n\n\n10\n\n6 : 11,13\n\n\n\n9->10\n\n\n\n\n\n11\n\n7 : 11,13\n\n\n\n9->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61b50>" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "for j in range(15): ordering.push_back(j)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "show(bayesTree)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/notebooks/DiscreteSwitching.ipynb b/python/gtsam/notebooks/DiscreteSwitching.ipynb new file mode 100644 index 0000000000..6872e78c80 --- /dev/null +++ b/python/gtsam/notebooks/DiscreteSwitching.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A Discrete Switching System\n", + "\n", + "A la MHS, but all discrete.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n", + "from gtsam.symbol_shorthand import M\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " # TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "\n", + "\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nrStates = 3\n", + "K = 5\n", + "\n", + "bayesNet = DiscreteBayesNet()\n", + "for k in range(1, K):\n", + " key = S(k), nrStates\n", + " key_plus = S(k+1), nrStates\n", + " mode = M(k), 2\n", + " bayesNet.add(key_plus, P(mode, key), \"9/1/0 1/8/1 0/1/9 1/9/0 0/1/9 9/0/1\")\n", + "\n", + "bayesNet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "# First eliminate \"continuous\" states in time order\n", + "for k in range(1, K+1):\n", + " ordering.push_back(S(k))\n", + "for k in range(1, K):\n", + " ordering.push_back(M(k))\n", + "print(ordering)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "bayesTree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show(bayesTree)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/preamble/discrete.h b/python/gtsam/preamble/discrete.h new file mode 100644 index 0000000000..608508c32f --- /dev/null +++ b/python/gtsam/preamble/discrete.h @@ -0,0 +1,16 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include + +PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys); diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h new file mode 100644 index 0000000000..458a2ea4c0 --- /dev/null +++ b/python/gtsam/specializations/discrete.h @@ -0,0 +1,17 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `py::bind_vector` and similar machinery gives the std container a Python-like + * interface, but without the `` copying mechanism. Combined + * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, + * and saves one copy operation. + */ + +// Seems this is not a good idea with inherited stl +//py::bind_vector>(m_, "DiscreteKeys"); + +py::bind_map(m_, "DiscreteValues"); diff --git a/python/gtsam/tests/testEssentialMatrixConstraint.py b/python/gtsam/tests/testEssentialMatrixConstraint.py new file mode 100644 index 0000000000..8439ad2e93 --- /dev/null +++ b/python/gtsam/tests/testEssentialMatrixConstraint.py @@ -0,0 +1,47 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +visual_isam unit tests. +Author: Frank Dellaert & Pablo Alcantarilla +""" + +import unittest + +import gtsam +import numpy as np +from gtsam import (EssentialMatrix, EssentialMatrixConstraint, Point3, Pose3, + Rot3, Unit3, symbol) +from gtsam.utils.test_case import GtsamTestCase + + +class TestVisualISAMExample(GtsamTestCase): + def test_VisualISAMExample(self): + + # Create a factor + poseKey1 = symbol('x', 1) + poseKey2 = symbol('x', 2) + trueRotation = Rot3.RzRyRx(0.15, 0.15, -0.20) + trueTranslation = Point3(+0.5, -1.0, +1.0) + trueDirection = Unit3(trueTranslation) + E = EssentialMatrix(trueRotation, trueDirection) + model = gtsam.noiseModel.Isotropic.Sigma(5, 0.25) + factor = EssentialMatrixConstraint(poseKey1, poseKey2, E, model) + + # Create a linearization point at the zero-error point + pose1 = Pose3(Rot3.RzRyRx(0.00, -0.15, 0.30), Point3(-4.0, 7.0, -10.0)) + pose2 = Pose3( + Rot3.RzRyRx(0.179693265735950, 0.002945368776519, + 0.102274823253840), + Point3(-3.37493895, 6.14660244, -8.93650986)) + + expected = np.zeros((5, 1)) + actual = factor.evaluateError(pose1, pose2) + self.gtsamAssertEquals(actual, expected, 1e-8) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py new file mode 100644 index 0000000000..12a60d5cb1 --- /dev/null +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -0,0 +1,54 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for DecisionTreeFactors. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys +from gtsam.utils.test_case import GtsamTestCase + + +class TestDecisionTreeFactor(GtsamTestCase): + """Tests for DecisionTreeFactors.""" + + def setUp(self): + A = (12, 3) + B = (5, 2) + self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6") + + def test_enumerate(self): + actual = self.factor.enumerate() + _, values = zip(*actual) + self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + expected = \ + "|A|B|value|\n" \ + "|:-:|:-:|:-:|\n" \ + "|0|0|1|\n" \ + "|0|1|2|\n" \ + "|1|0|3|\n" \ + "|1|1|4|\n" \ + "|2|0|5|\n" \ + "|2|1|6|\n" + + def formatter(x: int): + return "A" if x == 12 else "B" + + actual = self.factor._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py new file mode 100644 index 0000000000..bdd5a05464 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -0,0 +1,112 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes Nets. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, + DiscreteKeys, DiscretePrior, DiscreteValues, Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_constructor(self): + """Test constructing a Bayes net.""" + + bayesNet = DiscreteBayesNet() + Parent, Child = (0, 2), (1, 2) + empty = DiscreteKeys() + prior = DiscreteConditional(Parent, empty, "6/4") + bayesNet.add(prior) + + parents = DiscreteKeys() + parents.push_back(Parent) + conditional = DiscreteConditional(Child, parents, "7/3 8/2") + bayesNet.add(conditional) + + # Check conversion to factor graph: + fg = DiscreteFactorGraph(bayesNet) + self.assertEqual(fg.size(), 2) + self.assertEqual(fg.at(1).size(), 2) + + def test_Asia(self): + """Test full Asia example.""" + + Asia = (0, 2) + Smoking = (4, 2) + Tuberculosis = (3, 2) + LungCancer = (6, 2) + + Bronchitis = (7, 2) + Either = (5, 2) + XRay = (2, 2) + Dyspnea = (1, 2) + + asia = DiscreteBayesNet() + asia.add(Asia, "99/1") + asia.add(Smoking, "50/50") + + asia.add(Tuberculosis, [Asia], "99/1 95/5") + asia.add(LungCancer, [Smoking], "99/1 90/10") + asia.add(Bronchitis, [Smoking], "70/30 40/60") + + asia.add(Either, [Tuberculosis, LungCancer], "F T T T") + + asia.add(XRay, [Either], "95/5 2/98") + asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9") + + # Convert to factor graph + fg = DiscreteFactorGraph(asia) + + # Create solver and eliminate + ordering = Ordering() + for j in range(8): + ordering.push_back(j) + chordal = fg.eliminateSequential(ordering) + expected2 = DiscretePrior(Bronchitis, "11/9") + self.gtsamAssertEquals(chordal.at(7), expected2) + + # solve + actualMPE = chordal.optimize() + expectedMPE = DiscreteValues() + for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: + expectedMPE[key[0]] = 0 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + # Check value for MPE is the same + self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) + + # add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1") + fg.add(Dyspnea, "0 1") + + # solve again, now with evidence + chordal2 = fg.eliminateSequential(ordering) + actualMPE2 = chordal2.optimize() + expectedMPE2 = DiscreteValues() + for key in [XRay, Tuberculosis, Either, LungCancer]: + expectedMPE2[key[0]] = 0 + for key in [Asia, Dyspnea, Smoking, Bronchitis]: + expectedMPE2[key[0]] = 1 + self.assertEqual(list(actualMPE2.items()), + list(expectedMPE2.items())) + + # now sample from it + actualSample = chordal2.sample() + self.assertEqual(len(actualSample), 8) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteBayesTree.dot b/python/gtsam/tests/test_DiscreteBayesTree.dot new file mode 100644 index 0000000000..d7cf7d9bc0 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.dot @@ -0,0 +1,25 @@ +digraph G{ +0[label="8,12,14"]; +0->1 +1[label="0 : 8,12"]; +0->2 +2[label="1 : 8,12"]; +0->3 +3[label="9 : 12,14"]; +3->4 +4[label="2 : 9,12"]; +3->5 +5[label="3 : 9,12"]; +0->6 +6[label="10,13 : 14"]; +6->7 +7[label="4 : 10,13"]; +6->8 +8[label="5 : 10,13"]; +6->9 +9[label="11 : 13,14"]; +9->10 +10[label="6 : 11,13"]; +9->11 +11[label="7 : 11,13"]; +} \ No newline at end of file diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py new file mode 100644 index 0000000000..b1ed4fe696 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -0,0 +1,79 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes trees. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, + DiscreteConditional, DiscreteFactorGraph, Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_elimination(self): + """Test Multifrontal elimination.""" + + # Define DiscreteKey pairs. + keys = [(j, 2) for j in range(15)] + + # Create thin-tree Bayesnet. + bayesNet = DiscreteBayesNet() + + bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1") + bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1") + bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4") + bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[12], [keys[14]], "3/1 3/1") + bayesNet.add(keys[13], [keys[14]], "1/3 3/1") + + bayesNet.add(keys[14], "1/3") + + # Create a factor graph out of the Bayes net. + factorGraph = DiscreteFactorGraph(bayesNet) + + # Create a BayesTree out of the factor graph. + ordering = Ordering() + for j in range(15): + ordering.push_back(j) + bayesTree = factorGraph.eliminateMultifrontal(ordering) + + # Uncomment these for visualization: + # print(bayesTree) + # for key in range(15): + # bayesTree[key].printSignature() + # bayesTree.saveGraph("test_DiscreteBayesTree.dot") + + self.assertFalse(bayesTree.empty()) + self.assertEqual(12, bayesTree.size()) + + # The root is P( 8 12 14), we can retrieve it by key: + root = bayesTree[8] + self.assertIsInstance(root, DiscreteBayesTreeClique) + self.assertTrue(root.isRoot()) + self.assertIsInstance(root.conditional(), DiscreteConditional) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py new file mode 100644 index 0000000000..1b2ce70cd7 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -0,0 +1,71 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Conditionals. +Author: Varun Agrawal +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteConditional(GtsamTestCase): + """Tests for Discrete Conditionals.""" + + def test_single_value_versions(self): + X = (0, 2) + Y = (1, 3) + conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") + + actual0 = conditional.likelihood(0) + expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5") + self.gtsamAssertEquals(actual0, expected0, 1e-9) + + actual1 = conditional.likelihood(1) + expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") + self.gtsamAssertEquals(actual1, expected1, 1e-9) + + actual = conditional.sample(2) + self.assertIsInstance(actual, int) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + A = (2, 2) + B = (1, 2) + C = (0, 3) + parents = DiscreteKeys() + parents.push_back(B) + parents.push_back(C) + + conditional = DiscreteConditional(A, parents, + "0/1 1/3 1/1 3/1 0/1 1/0") + expected = \ + " *P(A|B,C)*:\n\n" \ + "|B|C|0|1|\n" \ + "|:-:|:-:|:-:|:-:|\n" \ + "|0|0|0|1|\n" \ + "|0|1|0.25|0.75|\n" \ + "|0|2|0.5|0.5|\n" \ + "|1|0|0.75|0.25|\n" \ + "|1|1|0|1|\n" \ + "|1|2|1|0|\n" + + def formatter(x: int): + names = ["C", "B", "A"] + return names[x] + + actual = conditional._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py new file mode 100644 index 0000000000..1ba145e096 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -0,0 +1,122 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Factor Graphs. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteFactorGraph(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_evaluation(self): + """Test constructing and evaluating a discrete factor graph.""" + + # Three keys + P1 = (0, 2) + P2 = (1, 2) + P3 = (2, 3) + + # Create the DiscreteFactorGraph + graph = DiscreteFactorGraph() + + # Add two unary factors (priors) + graph.add(P1, [0.9, 0.3]) + graph.add(P2, "0.9 0.6") + + # Add a binary factor + graph.add([P1, P2], "4 1 10 4") + + # Instantiate Values + assignment = DiscreteValues() + assignment[0] = 1 + assignment[1] = 1 + + # Check if graph evaluation works ( 0.3*0.6*4 ) + self.assertAlmostEqual(.72, graph(assignment)) + + # Create a new test with third node and adding unary and ternary factor + graph.add(P3, "0.9 0.2 0.5") + keys = DiscreteKeys() + keys.push_back(P1) + keys.push_back(P2) + keys.push_back(P3) + graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12") + + # Below assignment selects the 8th index in the ternary factor table + assignment[0] = 1 + assignment[1] = 0 + assignment[2] = 1 + + # Check if graph evaluation works (0.3*0.9*1*0.2*8) + self.assertAlmostEqual(4.32, graph(assignment)) + + # Below assignment selects the 3rd index in the ternary factor table + assignment[0] = 0 + assignment[1] = 1 + assignment[2] = 0 + + # Check if graph evaluation works (0.9*0.6*1*0.9*4) + self.assertAlmostEqual(1.944, graph(assignment)) + + # Check if graph product works + product = graph.product() + self.assertAlmostEqual(1.944, product(assignment)) + + def test_optimize(self): + """Test constructing and optizing a discrete factor graph.""" + + # Three keys + C = (0, 2) + B = (1, 2) + A = (2, 2) + + # A simple factor graph (A)-fAC-(C)-fBC-(B) + # with smoothness priors + graph = DiscreteFactorGraph() + graph.add([A, C], "3 1 1 3") + graph.add([C, B], "3 1 1 3") + + # Test optimization + expectedValues = DiscreteValues() + expectedValues[0] = 0 + expectedValues[1] = 0 + expectedValues[2] = 0 + actualValues = graph.optimize() + self.assertEqual(list(actualValues.items()), + list(expectedValues.items())) + + def test_MPE(self): + """Test maximum probable explanation (MPE): same as optimize.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + actualMPE = graph.optimize() + + expectedMPE = DiscreteValues() + expectedMPE[0] = 0 + expectedMPE[1] = 1 + expectedMPE[2] = 1 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py new file mode 100644 index 0000000000..4f017d66a4 --- /dev/null +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -0,0 +1,60 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Priors. +Author: Varun Agrawal +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +import numpy as np +from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior +from gtsam.utils.test_case import GtsamTestCase + +X = 0, 2 + + +class TestDiscretePrior(GtsamTestCase): + """Tests for Discrete Priors.""" + + def test_constructor(self): + """Test various constructors.""" + actual = DiscretePrior(X, "2/3") + keys = DiscreteKeys() + keys.push_back(X) + f = DecisionTreeFactor(keys, "0.4 0.6") + expected = DiscretePrior(f) + self.gtsamAssertEquals(actual, expected) + + def test_operator(self): + prior = DiscretePrior(X, "2/3") + self.assertAlmostEqual(prior(0), 0.4) + self.assertAlmostEqual(prior(1), 0.6) + + def test_pmf(self): + prior = DiscretePrior(X, "2/3") + expected = np.array([0.4, 0.6]) + np.testing.assert_allclose(expected, prior.pmf()) + + def test_markdown(self): + """Test the _repr_markdown_ method.""" + + prior = DiscretePrior(X, "2/3") + expected = " *P(0)*:\n\n" \ + "|0|value|\n" \ + "|:-:|:-:|\n" \ + "|0|0.4|\n" \ + "|1|0.6|\n" \ + + actual = prior._repr_markdown_() + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam_unstable/gtsam_unstable.tpl b/python/gtsam_unstable/gtsam_unstable.tpl index aa7ac6bdb8..055fcaea78 100644 --- a/python/gtsam_unstable/gtsam_unstable.tpl +++ b/python/gtsam_unstable/gtsam_unstable.tpl @@ -40,7 +40,7 @@ PYBIND11_MODULE({module_name}, m_) {{ {wrapped_namespace} -#include "python/gtsam_unstable/specializations.h" +#include "python/gtsam_unstable/specializations/gtsam_unstable.h" }} diff --git a/python/gtsam_unstable/specializations.h b/python/gtsam_unstable/specializations/gtsam_unstable.h similarity index 100% rename from python/gtsam_unstable/specializations.h rename to python/gtsam_unstable/specializations/gtsam_unstable.h diff --git a/tests/testNonlinearFactorGraph.cpp b/tests/testNonlinearFactorGraph.cpp index fdb080a63b..4dec08f45c 100644 --- a/tests/testNonlinearFactorGraph.cpp +++ b/tests/testNonlinearFactorGraph.cpp @@ -15,6 +15,7 @@ * @brief testNonlinearFactorGraph * @author Carlos Nieto * @author Christian Potthast + * @author Frank Dellaert */ #include @@ -285,6 +286,7 @@ TEST(testNonlinearFactorGraph, addPrior) { EXPECT(0 != graph.error(values)); } +/* ************************************************************************* */ TEST(NonlinearFactorGraph, printErrors) { const NonlinearFactorGraph fg = createNonlinearFactorGraph(); @@ -309,6 +311,53 @@ TEST(NonlinearFactorGraph, printErrors) for (bool visit : visited) EXPECT(visit==true); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var7782220156096217089[label=\"l1\"];\n" + " var8646911284551352321[label=\"x1\"];\n" + " var8646911284551352322[label=\"x2\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " var8646911284551352321--factor0;\n" + " var8646911284551352321--var8646911284551352322;\n" + " var8646911284551352321--var7782220156096217089;\n" + " var8646911284551352322--var7782220156096217089;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + string actual = fg.dot(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot_extra) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n" + " var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n" + " var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " var8646911284551352321--factor0;\n" + " var8646911284551352321--var8646911284551352322;\n" + " var8646911284551352321--var7782220156096217089;\n" + " var8646911284551352322--var7782220156096217089;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + const Values c = createValues(); + + stringstream ss; + fg.dot(ss, c); + EXPECT(ss.str() == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ diff --git a/wrap/cmake/MatlabWrap.cmake b/wrap/cmake/MatlabWrap.cmake index 083b88566f..3cb0581028 100644 --- a/wrap/cmake/MatlabWrap.cmake +++ b/wrap/cmake/MatlabWrap.cmake @@ -62,10 +62,10 @@ macro(find_and_configure_matlab) endmacro() # Consistent and user-friendly wrap function -function(matlab_wrap interfaceHeader linkLibraries +function(matlab_wrap interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags ignore_classes) find_and_configure_matlab() - wrap_and_install_library("${interfaceHeader}" "${linkLibraries}" + wrap_and_install_library("${interfaceHeader}" "${moduleName}" "${linkLibraries}" "${extraIncludeDirs}" "${extraMexFlags}" "${ignore_classes}") endfunction() @@ -77,6 +77,7 @@ endfunction() # Arguments: # # interfaceHeader: The relative path to the wrapper interface definition file. +# moduleName: The name of the wrapped module, e.g. gtsam # linkLibraries: Any *additional* libraries to link. Your project library # (e.g. `lba`), libraries it depends on, and any necessary MATLAB libraries will # be linked automatically. So normally, leave this empty. @@ -85,15 +86,15 @@ endfunction() # extraMexFlags: Any *additional* flags to pass to the compiler when building # the wrap code. Normally, leave this empty. # ignore_classes: List of classes to ignore in the wrapping. -function(wrap_and_install_library interfaceHeader linkLibraries +function(wrap_and_install_library interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags ignore_classes) - wrap_library_internal("${interfaceHeader}" "${linkLibraries}" + wrap_library_internal("${interfaceHeader}" "${moduleName}" "${linkLibraries}" "${extraIncludeDirs}" "${mexFlags}") - install_wrapped_library_internal("${interfaceHeader}") + install_wrapped_library_internal("${moduleName}") endfunction() # Internal function that wraps a library and compiles the wrapper -function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs +function(wrap_library_internal interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags) if(UNIX AND NOT APPLE) if(CMAKE_SIZEOF_VOID_P EQUAL 8) @@ -120,7 +121,6 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs # Extract module name from interface header file name get_filename_component(interfaceHeader "${interfaceHeader}" ABSOLUTE) get_filename_component(modulePath "${interfaceHeader}" PATH) - get_filename_component(moduleName "${interfaceHeader}" NAME_WE) # Paths for generated files set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}") @@ -136,8 +136,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs # explicit link libraries list so that the next block of code can unpack any # static libraries set(automaticDependencies "") - foreach(lib ${moduleName} ${linkLibraries}) - # message("MODULE NAME: ${moduleName}") + foreach(lib ${module} ${linkLibraries}) if(TARGET "${lib}") get_target_property(dependentLibraries ${lib} INTERFACE_LINK_LIBRARIES) # message("DEPENDENT LIBRARIES: ${dependentLibraries}") @@ -176,7 +175,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs set(otherLibraryTargets "") set(otherLibraryNontargets "") set(otherSourcesAndObjects "") - foreach(lib ${moduleName} ${linkLibraries} ${automaticDependencies}) + foreach(lib ${module} ${linkLibraries} ${automaticDependencies}) if(TARGET "${lib}") if(WRAP_MEX_BUILD_STATIC_MODULE) get_target_property(target_sources ${lib} SOURCES) @@ -250,7 +249,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" - ${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src ${interfaceHeader} + ${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src "${interfaceHeader}" --module_name ${moduleName} --out ${generated_files_path} --top_module_namespaces ${moduleName} --ignore ${ignore_classes} VERBATIM @@ -324,8 +323,8 @@ endfunction() # Internal function that installs a wrap toolbox function(install_wrapped_library_internal interfaceHeader) - get_filename_component(moduleName "${interfaceHeader}" NAME_WE) - set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}") + get_filename_component(module "${interfaceHeader}" NAME_WE) + set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${module}") # NOTE: only installs .m and mex binary files (not .cpp) - the trailing slash # on the directory name here prevents creating the top-level module name diff --git a/wrap/cmake/PybindWrap.cmake b/wrap/cmake/PybindWrap.cmake index 2149c7195e..2008bf2ddf 100644 --- a/wrap/cmake/PybindWrap.cmake +++ b/wrap/cmake/PybindWrap.cmake @@ -55,15 +55,44 @@ function( set(GTWRAP_PATH_SEPARATOR ";") endif() + # Create a copy of interface_headers so we can freely manipulate it + set(interface_files ${interface_headers}) + + # Pop the main interface file so that interface_files has only submodules. + list(POP_FRONT interface_files main_interface) + # Convert .i file names to .cpp file names. - foreach(filepath ${interface_headers}) - get_filename_component(interface ${filepath} NAME) - string(REPLACE ".i" ".cpp" cpp_file ${interface}) + foreach(interface_file ${interface_files}) + # This block gets the interface file name and does the replacement + get_filename_component(interface ${interface_file} NAME_WLE) + set(cpp_file "${interface}.cpp") list(APPEND cpp_files ${cpp_file}) + + # Wrap the specific interface header + # This is done so that we can create CMake dependencies in such a way so that when changing a single .i file, + # the others don't need to be regenerated. + # NOTE: We have to use `add_custom_command` so set the dependencies correctly. + # https://stackoverflow.com/questions/40032593/cmake-does-not-rebuild-dependent-after-prerequisite-changes + add_custom_command( + OUTPUT ${cpp_file} + COMMAND + ${CMAKE_COMMAND} -E env + "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" + ${PYTHON_EXECUTABLE} ${PYBIND_WRAP_SCRIPT} --src "${interface_file}" + --out "${cpp_file}" --module_name ${module_name} + --top_module_namespaces "${top_namespace}" --ignore ${ignore_classes} + --template ${module_template} --is_submodule ${_WRAP_BOOST_ARG} + DEPENDS "${interface_file}" ${module_template} "${module_name}/specializations/${interface}.h" "${module_name}/preamble/${interface}.h" + VERBATIM) + endforeach() + get_filename_component(main_interface_name ${main_interface} NAME_WLE) + set(main_cpp_file "${main_interface_name}.cpp") + list(PREPEND cpp_files ${main_cpp_file}) + add_custom_command( - OUTPUT ${cpp_files} + OUTPUT ${main_cpp_file} COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" @@ -71,23 +100,10 @@ function( --out "${generated_cpp}" --module_name ${module_name} --top_module_namespaces "${top_namespace}" --ignore ${ignore_classes} --template ${module_template} ${_WRAP_BOOST_ARG} - DEPENDS "${interface_headers}" ${module_template} + DEPENDS "${main_interface}" ${module_template} "${module_name}/specializations/${main_interface_name}.h" "${module_name}/specializations/${main_interface_name}.h" VERBATIM) - add_custom_target(pybind_wrap_${module_name} ALL DEPENDS ${cpp_files}) - - # Late dependency injection, to make sure this gets called whenever the - # interface header or the wrap library are updated. - # ~~~ - # See: https://stackoverflow.com/questions/40032593/cmake-does-not-rebuild-dependent-after-prerequisite-changes - # ~~~ - add_custom_command( - OUTPUT ${cpp_files} - DEPENDS ${interface_headers} - # @GTWRAP_SOURCE_DIR@/gtwrap/interface_parser.py - # @GTWRAP_SOURCE_DIR@/gtwrap/pybind_wrapper.py - # @GTWRAP_SOURCE_DIR@/gtwrap/template_instantiator.py - APPEND) + add_custom_target(pybind_wrap_${module_name} DEPENDS ${cpp_files}) pybind11_add_module(${target} "${cpp_files}") diff --git a/wrap/gtwrap/interface_parser/type.py b/wrap/gtwrap/interface_parser/type.py index e94db4ff2d..7aacf0b81a 100644 --- a/wrap/gtwrap/interface_parser/type.py +++ b/wrap/gtwrap/interface_parser/type.py @@ -53,6 +53,10 @@ def __init__(self, self.name = t[-1] # the name is the last element in this list self.namespaces = t[:-1] + # If the first namespace is empty string, just get rid of it. + if self.namespaces and self.namespaces[0] == '': + self.namespaces.pop(0) + if instantiations: if isinstance(instantiations, Sequence): self.instantiations = instantiations # type: ignore @@ -92,8 +96,8 @@ def to_cpp(self) -> str: else: cpp_name = self.name return '{}{}{}'.format( - "::".join(self.namespaces[idx:]), - "::" if self.namespaces[idx:] else "", + "::".join(self.namespaces), + "::" if self.namespaces else "", cpp_name, ) diff --git a/wrap/gtwrap/matlab_wrapper/mixins.py b/wrap/gtwrap/matlab_wrapper/mixins.py index 2d7c75b397..f4a7988fda 100644 --- a/wrap/gtwrap/matlab_wrapper/mixins.py +++ b/wrap/gtwrap/matlab_wrapper/mixins.py @@ -108,7 +108,7 @@ def _format_type_name(self, elif is_method: formatted_type_name += self.data_type_param.get(name) or name else: - formatted_type_name += name + formatted_type_name += str(name) if separator == "::": # C++ templates = [] @@ -192,10 +192,9 @@ def _format_static_method(self, method = '' if isinstance(static_method, parser.StaticMethod): - method += "".join([separator + x for x in static_method.parent.namespaces()]) + \ - separator + static_method.parent.name + separator + method += static_method.parent.to_cpp() + separator - return method[2 * len(separator):] + return method def _format_global_function(self, function: Union[parser.GlobalFunction, Any], diff --git a/wrap/gtwrap/matlab_wrapper/templates.py b/wrap/gtwrap/matlab_wrapper/templates.py index 7aaf8f487b..3d1306dca9 100644 --- a/wrap/gtwrap/matlab_wrapper/templates.py +++ b/wrap/gtwrap/matlab_wrapper/templates.py @@ -66,7 +66,7 @@ class WrapperTemplate: mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) {{ + if(mexPutVariable("global", "gtsam_{module_name}_rttiRegistry_created", newAlreadyCreated) != 0) {{ mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); }} mxDestroyArray(newAlreadyCreated); diff --git a/wrap/gtwrap/matlab_wrapper/wrapper.py b/wrap/gtwrap/matlab_wrapper/wrapper.py index 97945f73a2..42610999df 100755 --- a/wrap/gtwrap/matlab_wrapper/wrapper.py +++ b/wrap/gtwrap/matlab_wrapper/wrapper.py @@ -5,6 +5,7 @@ # pylint: disable=too-many-lines, no-self-use, too-many-arguments, too-many-branches, too-many-statements +import copy import os import os.path as osp import textwrap @@ -13,6 +14,7 @@ import gtwrap.interface_parser as parser import gtwrap.template_instantiator as instantiator +from gtwrap.interface_parser.function import ArgumentList from gtwrap.matlab_wrapper.mixins import CheckMixin, FormatMixin from gtwrap.matlab_wrapper.templates import WrapperTemplate @@ -137,6 +139,37 @@ def _insert_spaces(self, x, y): """ return x + '\n' + ('' if y == '' else ' ') + y + @staticmethod + def _expand_default_arguments(method, save_backup=True): + """Recursively expand all possibilities for optional default arguments. + We create "overload" functions with fewer arguments, but since we have to "remember" what + the default arguments are for later, we make a backup. + """ + def args_copy(args): + return ArgumentList([copy.copy(arg) for arg in args.list()]) + def method_copy(method): + method2 = copy.copy(method) + method2.args = args_copy(method.args) + method2.args.backup = method.args.backup + return method2 + if save_backup: + method.args.backup = args_copy(method.args) + method = method_copy(method) + for arg in reversed(method.args.list()): + if arg.default is not None: + arg.default = None + methodWithArg = method_copy(method) + method.args.list().remove(arg) + return [ + methodWithArg, + *MatlabWrapper._expand_default_arguments(method, save_backup=False) + ] + break + assert all(arg.default is None for arg in method.args.list()), \ + 'In parsing method {:}: Arguments with default values cannot appear before ones ' \ + 'without default values.'.format(method.name) + return [method] + def _group_methods(self, methods): """Group overloaded methods together""" method_map = {} @@ -147,9 +180,9 @@ def _group_methods(self, methods): if method_index is None: method_map[method.name] = len(method_out) - method_out.append([method]) + method_out.append(MatlabWrapper._expand_default_arguments(method)) else: - method_out[method_index].append(method) + method_out[method_index] += MatlabWrapper._expand_default_arguments(method) return method_out @@ -239,18 +272,18 @@ def _wrap_list_variable_arguments(self, args): return var_list_wrap - def _wrap_method_check_statement(self, args): + def _wrap_method_check_statement(self, args: parser.ArgumentList): """ Wrap the given arguments into either just a varargout call or a call in an if statement that checks if the parameters are accurate. + + TODO Update this method so that default arguments are supported. """ - check_statement = '' arg_id = 1 - if check_statement == '': - check_statement = \ - 'if length(varargin) == {param_count}'.format( - param_count=len(args.list())) + param_count = len(args) + check_statement = 'if length(varargin) == {param_count}'.format( + param_count=param_count) for _, arg in enumerate(args.list()): name = arg.ctype.typename.name @@ -301,13 +334,9 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): ((a), Test& t = *unwrap_shared_ptr< Test >(in[1], "ptr_Test");), ((a), std::shared_ptr p1 = unwrap_shared_ptr< Test >(in[1], "ptr_Test");) """ - params = '' body_args = '' for arg in args.list(): - if params != '': - params += ',' - if self.is_ref(arg.ctype): # and not constructor: ctype_camel = self._format_type_name(arg.ctype.typename, separator='') @@ -336,8 +365,6 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): name=arg.name, id=arg_id)), prefix=' ') - if call_type == "": - params += "*" else: body_args += textwrap.indent(textwrap.dedent('''\ @@ -347,10 +374,29 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): id=arg_id)), prefix=' ') - params += arg.name - arg_id += 1 + params = '' + explicit_arg_names = [arg.name for arg in args.list()] + # when returning the params list, we need to re-include the default args. + for arg in args.backup.list(): + if params != '': + params += ',' + + if (arg.default is not None) and (arg.name not in explicit_arg_names): + params += arg.default + continue + + if (not self.is_ref(arg.ctype)) and (self.is_shared_ptr(arg.ctype)) and (self.is_ptr( + arg.ctype)) and (arg.ctype.typename.name not in self.ignore_namespace): + if arg.ctype.is_shared_ptr: + call_type = arg.ctype.is_shared_ptr + else: + call_type = arg.ctype.is_ptr + if call_type == "": + params += "*" + params += arg.name + return params, body_args @staticmethod @@ -555,6 +601,8 @@ def wrap_class_constructors(self, namespace_name, inst_class, parent_name, if not isinstance(ctors, Iterable): ctors = [ctors] + ctors = sum((MatlabWrapper._expand_default_arguments(ctor) for ctor in ctors), []) + methods_wrap = textwrap.indent(textwrap.dedent("""\ methods function obj = {class_name}(varargin) @@ -674,20 +722,7 @@ def wrap_class_display(self): def _group_class_methods(self, methods): """Group overloaded methods together""" - method_map = {} - method_out = [] - - for method in methods: - method_index = method_map.get(method.name) - - if method_index is None: - method_map[method.name] = len(method_out) - method_out.append([method]) - else: - # print("[_group_methods] Merging {} with {}".format(method_index, method.name)) - method_out[method_index].append(method) - - return method_out + return self._group_methods(methods) @classmethod def _format_varargout(cls, return_type, return_type_formatted): @@ -809,7 +844,7 @@ def wrap_static_methods(self, namespace_name, instantiated_class, for static_method in static_methods: format_name = list(static_method[0].name) - format_name[0] = format_name[0].upper() + format_name[0] = format_name[0] if static_method[0].name in self.ignore_methods: continue @@ -850,12 +885,13 @@ def wrap_static_methods(self, namespace_name, instantiated_class, wrapper=self._wrapper_name(), id=self._update_wrapper_id( (namespace_name, instantiated_class, - static_overload.name, static_overload)), + static_overload.name, static_overload)), class_name=instantiated_class.name, end_statement=end_statement), - prefix=' ') + prefix=' ') - #TODO Figure out what is static_overload doing here. + # If the arguments don't match any of the checks above, + # throw an error with the class and method name. method_text += textwrap.indent(textwrap.dedent("""\ error('Arguments do not match any overload of function {class_name}.{method_name}'); """.format(class_name=class_name, @@ -1081,7 +1117,6 @@ def wrap_collector_function_return(self, method): obj_start = '' if isinstance(method, instantiator.InstantiatedMethod): - # method_name = method.original.name method_name = method.to_cpp() obj_start = 'obj->' @@ -1090,6 +1125,10 @@ def wrap_collector_function_return(self, method): # self._format_type_name(method.instantiations)) method = method.to_cpp() + elif isinstance(method, instantiator.InstantiatedStaticMethod): + method_name = self._format_static_method(method, '::') + method_name += method.original.name + elif isinstance(method, parser.GlobalFunction): method_name = self._format_global_function(method, '::') method_name += method.name @@ -1230,9 +1269,9 @@ def generate_collector_function(self, func_id): Collector_{class_name}::iterator item; item = collector_{class_name}.find(self); if(item != collector_{class_name}.end()) {{ - delete self; collector_{class_name}.erase(item); }} + delete self; ''').format(class_name_sep=class_name_separated, class_name=class_name), prefix=' ') @@ -1250,7 +1289,7 @@ def generate_collector_function(self, func_id): method_name = '' if is_static_method: - method_name = self._format_static_method(extra) + '.' + method_name = self._format_static_method(extra, '.') method_name += extra.name @@ -1567,23 +1606,23 @@ def generate_content(self, cc_content, path): def wrap(self, files, path): """High level function to wrap the project.""" + content = "" modules = {} for file in files: with open(file, 'r') as f: - content = f.read() + content += f.read() - # Parse the contents of the interface file - parsed_result = parser.Module.parseString(content) - # print(parsed_result) + # Parse the contents of the interface file + parsed_result = parser.Module.parseString(content) - # Instantiate the module - module = instantiator.instantiate_namespace(parsed_result) + # Instantiate the module + module = instantiator.instantiate_namespace(parsed_result) - if module.name in modules: - modules[module. - name].content[0].content += module.content[0].content - else: - modules[module.name] = module + if module.name in modules: + modules[ + module.name].content[0].content += module.content[0].content + else: + modules[module.name] = module for module in modules.values(): # Wrap the full namespace diff --git a/wrap/gtwrap/pybind_wrapper.py b/wrap/gtwrap/pybind_wrapper.py index 809c69b56e..1a3f10bf52 100755 --- a/wrap/gtwrap/pybind_wrapper.py +++ b/wrap/gtwrap/pybind_wrapper.py @@ -14,6 +14,7 @@ import re from pathlib import Path +from typing import List import gtwrap.interface_parser as parser import gtwrap.template_instantiator as instantiator @@ -46,6 +47,11 @@ def __init__(self, # amount of indentation to add before each function/method declaration. self.method_indent = '\n' + (' ' * 8) + # Special methods which are leveraged by ipython/jupyter notebooks + self._ipython_special_methods = [ + "svg", "png", "jpeg", "html", "javascript", "markdown", "latex" + ] + def _py_args_names(self, args): """Set the argument names in Pybind11 format.""" names = args.names() @@ -86,34 +92,99 @@ def wrap_ctors(self, my_class): )) return res + def _wrap_serialization(self, cpp_class): + """Helper method to add serialize, deserialize and pickle methods to the wrapped class.""" + if not cpp_class in self._serializing_classes: + self._serializing_classes.append(cpp_class) + + serialize_method = self.method_indent + \ + ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') + + deserialize_method = self.method_indent + \ + '.def("deserialize", []({class_inst} self, string serialized)' \ + '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \ + .format(class_inst=cpp_class + '*') + + # Since this class supports serialization, we also add the pickle method. + pickle_method = self.method_indent + \ + ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" + + return serialize_method + deserialize_method + \ + pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) + + def _wrap_print(self, ret: str, method: parser.Method, cpp_class: str, + args_names: List[str], args_signature_with_names: str, + py_args_names: str, prefix: str, suffix: str): + """ + Update the print method to print to the output stream and append a __repr__ method. + + Args: + ret (str): The result of the parser. + method (parser.Method): The method to be wrapped. + cpp_class (str): The C++ name of the class to which the method belongs. + args_names (List[str]): List of argument variable names passed to the method. + args_signature_with_names (str): C++ arguments containing their names and type signatures. + py_args_names (str): The pybind11 formatted version of the argument list. + prefix (str): Prefix to add to the wrapped method when writing to the cpp file. + suffix (str): Suffix to add to the wrapped method when writing to the cpp file. + + Returns: + str: The wrapped print method. + """ + # Redirect stdout - see pybind docs for why this is a good idea: + # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream + ret = ret.replace('self->print', + 'py::scoped_ostream_redirect output; self->print') + + # Make __repr__() call .print() internally + ret += '''{prefix}.def("__repr__", + [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ + gtsam::RedirectCout redirect; + self.{method_name}({method_args}); + return redirect.str(); + }}{py_args_names}){suffix}'''.format( + prefix=prefix, + cpp_class=cpp_class, + opt_comma=', ' if args_names else '', + args_signature_with_names=args_signature_with_names, + method_name=method.name, + method_args=", ".join(args_names) if args_names else '', + py_args_names=py_args_names, + suffix=suffix) + return ret + def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): + """ + Wrap the `method` for the class specified by `cpp_class`. + + Args: + method: The method to wrap. + cpp_class: The C++ name of the class to which the method belongs. + prefix: Prefix to add to the wrapped method when writing to the cpp file. + suffix: Suffix to add to the wrapped method when writing to the cpp file. + method_suffix: A string to append to the wrapped method name. + """ py_method = method.name + method_suffix cpp_method = method.to_cpp() + args_names = method.args.names() + py_args_names = self._py_args_names(method.args) + args_signature_with_names = self._method_args_signature(method.args) + + # Special handling for the serialize/serializable method if cpp_method in ["serialize", "serializable"]: - if not cpp_class in self._serializing_classes: - self._serializing_classes.append(cpp_class) - serialize_method = self.method_indent + \ - ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') - deserialize_method = self.method_indent + \ - '.def("deserialize", []({class_inst} self, string serialized)' \ - '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \ - .format(class_inst=cpp_class + '*') - return serialize_method + deserialize_method - - if cpp_method == "pickle": - if not cpp_class in self._serializing_classes: - raise ValueError( - "Cannot pickle a class which is not serializable") - pickle_method = self.method_indent + \ - ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" - return pickle_method.format(cpp_class=cpp_class, - indent=self.method_indent) + return self._wrap_serialization(cpp_class) + + # Special handling of ipython specific methods + # https://ipython.readthedocs.io/en/stable/config/integrating.html + if cpp_method in self._ipython_special_methods: + idx = self._ipython_special_methods.index(cpp_method) + py_method = f"_repr_{self._ipython_special_methods[idx]}_" # Add underscore to disambiguate if the method name matches a python keyword if py_method in self.python_keywords: @@ -125,9 +196,6 @@ def _wrap_method(self, method, (parser.StaticMethod, instantiator.InstantiatedStaticMethod)) return_void = method.return_type.is_void() - args_names = method.args.names() - py_args_names = self._py_args_names(method.args) - args_signature_with_names = self._method_args_signature(method.args) caller = cpp_class + "::" if not is_method else "self->" function_call = ('{opt_return} {caller}{method_name}' @@ -158,27 +226,9 @@ def _wrap_method(self, # Create __repr__ override # We allow all arguments to .print() and let the compiler handle type mismatches. if method.name == 'print': - # Redirect stdout - see pybind docs for why this is a good idea: - # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream - ret = ret.replace( - 'self->print', - 'py::scoped_ostream_redirect output; self->print') - - # Make __repr__() call .print() internally - ret += '''{prefix}.def("__repr__", - [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ - gtsam::RedirectCout redirect; - self.{method_name}({method_args}); - return redirect.str(); - }}{py_args_names}){suffix}'''.format( - prefix=prefix, - cpp_class=cpp_class, - opt_comma=', ' if args_names else '', - args_signature_with_names=args_signature_with_names, - method_name=method.name, - method_args=", ".join(args_names) if args_names else '', - py_args_names=py_args_names, - suffix=suffix) + ret = self._wrap_print(ret, method, cpp_class, args_names, + args_signature_with_names, py_args_names, + prefix, suffix) return ret @@ -624,28 +674,47 @@ def wrap_file(self, content, module_name=None, submodules=None): submodules_init="\n".join(submodules_init), ) - def wrap(self, sources, main_output): + def wrap_submodule(self, source): """ - Wrap all the source interface files. + Wrap a list of submodule files, i.e. a set of interface files which are + in support of a larger wrapping project. + + E.g. This is used in GTSAM where we have a main gtsam.i, but various smaller .i files + which are the submodules. + The benefit of this scheme is that it reduces compute and memory usage during compilation. + + Args: + source: Interface file which forms the submodule. + """ + filename = Path(source).name + module_name = Path(source).stem + + # Read in the complete interface (.i) file + with open(source, "r") as f: + content = f.read() + # Wrap the read-in content + cc_content = self.wrap_file(content, module_name=module_name) + + # Generate the C++ code which Pybind11 will use. + with open(filename.replace(".i", ".cpp"), "w") as f: + f.write(cc_content) + + def wrap(self, sources, main_module_name): + """ + Wrap all the main interface file. Args: sources: List of all interface files. - main_output: The name for the main module. + The first file should be the main module. + main_module_name: The name for the main module. """ main_module = sources[0] + + # Get all the submodule names. submodules = [] for source in sources[1:]: - filename = Path(source).name module_name = Path(source).stem - # Read in the complete interface (.i) file - with open(source, "r") as f: - content = f.read() submodules.append(module_name) - cc_content = self.wrap_file(content, module_name=module_name) - - # Generate the C++ code which Pybind11 will use. - with open(filename.replace(".i", ".cpp"), "w") as f: - f.write(cc_content) with open(main_module, "r") as f: content = f.read() @@ -654,5 +723,5 @@ def wrap(self, sources, main_output): submodules=submodules) # Generate the C++ code which Pybind11 will use. - with open(main_output, "w") as f: + with open(main_module_name, "w") as f: f.write(cc_content) diff --git a/wrap/gtwrap/template_instantiator/helpers.py b/wrap/gtwrap/template_instantiator/helpers.py index b55515dba6..194c6f686c 100644 --- a/wrap/gtwrap/template_instantiator/helpers.py +++ b/wrap/gtwrap/template_instantiator/helpers.py @@ -55,16 +55,14 @@ def instantiate_type( # make a deep copy so that there is no overwriting of original template params ctype = deepcopy(ctype) - # Check if the return type has template parameters + # Check if the return type has template parameters as the typename's name if ctype.typename.instantiations: for idx, instantiation in enumerate(ctype.typename.instantiations): if instantiation.name in template_typenames: template_idx = template_typenames.index(instantiation.name) - ctype.typename.instantiations[ - idx] = instantiations[ # type: ignore - template_idx] + ctype.typename.instantiations[idx].name =\ + instantiations[template_idx] - return ctype str_arg_typename = str(ctype.typename) @@ -125,9 +123,18 @@ def instantiate_type( # Case when 'This' is present in the type namespace, e.g `This::Subclass`. elif 'This' in str_arg_typename: - # Simply get the index of `This` in the namespace and replace it with the instantiated name. - namespace_idx = ctype.typename.namespaces.index('This') - ctype.typename.namespaces[namespace_idx] = cpp_typename.name + # Check if `This` is in the namespaces + if 'This' in ctype.typename.namespaces: + # Simply get the index of `This` in the namespace and + # replace it with the instantiated name. + namespace_idx = ctype.typename.namespaces.index('This') + ctype.typename.namespaces[namespace_idx] = cpp_typename.name + # Else check if it is in the template namespace, e.g vector + else: + for idx, instantiation in enumerate(ctype.typename.instantiations): + if 'This' in instantiation.namespaces: + ctype.typename.instantiations[idx].namespaces = \ + cpp_typename.namespaces + [cpp_typename.name] return ctype else: diff --git a/wrap/scripts/pybind_wrap.py b/wrap/scripts/pybind_wrap.py index c82a1d24c0..5770602439 100644 --- a/wrap/scripts/pybind_wrap.py +++ b/wrap/scripts/pybind_wrap.py @@ -19,7 +19,7 @@ def main(): arg_parser.add_argument("--src", type=str, required=True, - help="Input interface .i/.h file") + help="Input interface .i/.h file(s)") arg_parser.add_argument( "--module_name", type=str, @@ -31,7 +31,7 @@ def main(): "--out", type=str, required=True, - help="Name of the output pybind .cc file", + help="Name of the output pybind .cc file(s)", ) arg_parser.add_argument( "--use-boost", @@ -60,7 +60,10 @@ def main(): ) arg_parser.add_argument("--template", type=str, - help="The module template file") + help="The module template file (e.g. module.tpl).") + arg_parser.add_argument("--is_submodule", + default=False, + action="store_true") args = arg_parser.parse_args() top_module_namespaces = args.top_module_namespaces.split("::") @@ -78,9 +81,13 @@ def main(): module_template=template_content, ) - # Wrap the code and get back the cpp/cc code. - sources = args.src.split(';') - wrapper.wrap(sources, args.out) + if args.is_submodule: + wrapper.wrap_submodule(args.src) + + else: + # Wrap the code and get back the cpp/cc code. + sources = args.src.split(';') + wrapper.wrap(sources, args.out) if __name__ == "__main__": diff --git a/wrap/tests/actual/.gitignore b/wrap/tests/actual/.gitignore new file mode 100644 index 0000000000..7e0359a99d --- /dev/null +++ b/wrap/tests/actual/.gitignore @@ -0,0 +1,2 @@ +./* +!.gitignore diff --git a/wrap/tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m b/wrap/tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m new file mode 100644 index 0000000000..0ce4051afc --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m @@ -0,0 +1,31 @@ +%class GeneralSFMFactorCal3Bundler, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +classdef GeneralSFMFactorCal3Bundler < handle + properties + ptr_gtsamGeneralSFMFactorCal3Bundler = 0 + end + methods + function obj = GeneralSFMFactorCal3Bundler(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + special_cases_wrapper(7, my_ptr); + else + error('Arguments do not match any overload of gtsam.GeneralSFMFactorCal3Bundler constructor'); + end + obj.ptr_gtsamGeneralSFMFactorCal3Bundler = my_ptr; + end + + function delete(obj) + special_cases_wrapper(8, obj.ptr_gtsamGeneralSFMFactorCal3Bundler); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/Point3.m b/wrap/tests/expected/matlab/+gtsam/Point3.m index 06d378ac27..b3290faf2f 100644 --- a/wrap/tests/expected/matlab/+gtsam/Point3.m +++ b/wrap/tests/expected/matlab/+gtsam/Point3.m @@ -78,7 +78,7 @@ function delete(obj) error('Arguments do not match any overload of function Point3.StaticFunctionRet'); end - function varargout = StaticFunction(varargin) + function varargout = staticFunction(varargin) % STATICFUNCTION usage: staticFunction() : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 diff --git a/wrap/tests/expected/matlab/+gtsam/SfmTrack.m b/wrap/tests/expected/matlab/+gtsam/SfmTrack.m new file mode 100644 index 0000000000..428da2706d --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/SfmTrack.m @@ -0,0 +1,31 @@ +%class SfmTrack, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +classdef SfmTrack < handle + properties + ptr_gtsamSfmTrack = 0 + end + methods + function obj = SfmTrack(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + special_cases_wrapper(3, my_ptr); + else + error('Arguments do not match any overload of gtsam.SfmTrack constructor'); + end + obj.ptr_gtsamSfmTrack = my_ptr; + end + + function delete(obj) + special_cases_wrapper(4, obj.ptr_gtsamSfmTrack); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/Values.m b/wrap/tests/expected/matlab/+gtsam/Values.m new file mode 100644 index 0000000000..d85b24b911 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/Values.m @@ -0,0 +1,59 @@ +%class Values, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%Values() +%Values(Values other) +% +%-------Methods------- +%insert(size_t j, Vector vector) : returns void +%insert(size_t j, Matrix matrix) : returns void +% +classdef Values < handle + properties + ptr_gtsamValues = 0 + end + methods + function obj = Values(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + namespaces_wrapper(26, my_ptr); + elseif nargin == 0 + my_ptr = namespaces_wrapper(27); + elseif nargin == 1 && isa(varargin{1},'gtsam.Values') + my_ptr = namespaces_wrapper(28, varargin{1}); + else + error('Arguments do not match any overload of gtsam.Values constructor'); + end + obj.ptr_gtsamValues = my_ptr; + end + + function delete(obj) + namespaces_wrapper(29, obj.ptr_gtsamValues); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + function varargout = insert(this, varargin) + % INSERT usage: insert(size_t j, Vector vector) : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'double') && size(varargin{2},2)==1 + namespaces_wrapper(30, this, varargin{:}); + return + end + % INSERT usage: insert(size_t j, Matrix matrix) : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'double') + namespaces_wrapper(31, this, varargin{:}); + return + end + error('Arguments do not match any overload of function gtsam.Values.insert'); + end + + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+ns2/ClassA.m b/wrap/tests/expected/matlab/+ns2/ClassA.m index 4640e7cca9..71718ccbab 100644 --- a/wrap/tests/expected/matlab/+ns2/ClassA.m +++ b/wrap/tests/expected/matlab/+ns2/ClassA.m @@ -74,7 +74,7 @@ function delete(obj) end methods(Static = true) - function varargout = Afunction(varargin) + function varargout = afunction(varargin) % AFUNCTION usage: afunction() : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 diff --git a/wrap/tests/expected/matlab/DefaultFuncInt.m b/wrap/tests/expected/matlab/DefaultFuncInt.m new file mode 100644 index 0000000000..6c9c4116bc --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncInt.m @@ -0,0 +1,10 @@ +function varargout = DefaultFuncInt(varargin) + if length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') + functions_wrapper(8, varargin{:}); + elseif length(varargin) == 1 && isa(varargin{1},'numeric') + functions_wrapper(9, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(10, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncInt'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncObj.m b/wrap/tests/expected/matlab/DefaultFuncObj.m new file mode 100644 index 0000000000..15d9ba0fa3 --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncObj.m @@ -0,0 +1,8 @@ +function varargout = DefaultFuncObj(varargin) + if length(varargin) == 1 && isa(varargin{1},'gtsam.KeyFormatter') + functions_wrapper(14, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(15, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncObj'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncString.m b/wrap/tests/expected/matlab/DefaultFuncString.m new file mode 100644 index 0000000000..d26201c152 --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncString.m @@ -0,0 +1,10 @@ +function varargout = DefaultFuncString(varargin) + if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'char') + functions_wrapper(11, varargin{:}); + elseif length(varargin) == 1 && isa(varargin{1},'char') + functions_wrapper(12, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(13, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncString'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncVector.m b/wrap/tests/expected/matlab/DefaultFuncVector.m new file mode 100644 index 0000000000..344533fad0 --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncVector.m @@ -0,0 +1,10 @@ +function varargout = DefaultFuncVector(varargin) + if length(varargin) == 2 && isa(varargin{1},'std.vectornumeric') && isa(varargin{2},'std.vectorchar') + functions_wrapper(20, varargin{:}); + elseif length(varargin) == 1 && isa(varargin{1},'std.vectornumeric') + functions_wrapper(21, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(22, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncVector'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncZero.m b/wrap/tests/expected/matlab/DefaultFuncZero.m new file mode 100644 index 0000000000..0ebba2e5c7 --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncZero.m @@ -0,0 +1,12 @@ +function varargout = DefaultFuncZero(varargin) + if length(varargin) == 5 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'numeric') && isa(varargin{5},'logical') + functions_wrapper(16, varargin{:}); + elseif length(varargin) == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'numeric') + functions_wrapper(17, varargin{:}); + elseif length(varargin) == 3 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') + functions_wrapper(18, varargin{:}); + elseif length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') + functions_wrapper(19, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncZero'); + end diff --git a/wrap/tests/expected/matlab/ForwardKinematics.m b/wrap/tests/expected/matlab/ForwardKinematics.m new file mode 100644 index 0000000000..c2ff701c74 --- /dev/null +++ b/wrap/tests/expected/matlab/ForwardKinematics.m @@ -0,0 +1,38 @@ +%class ForwardKinematics, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%ForwardKinematics(Robot robot, string start_link_name, string end_link_name, Values joint_angles, Pose3 l2Tp) +% +classdef ForwardKinematics < handle + properties + ptr_ForwardKinematics = 0 + end + methods + function obj = ForwardKinematics(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + class_wrapper(57, my_ptr); + elseif nargin == 5 && isa(varargin{1},'gtdynamics.Robot') && isa(varargin{2},'char') && isa(varargin{3},'char') && isa(varargin{4},'gtsam.Values') && isa(varargin{5},'gtsam.Pose3') + my_ptr = class_wrapper(58, varargin{1}, varargin{2}, varargin{3}, varargin{4}, varargin{5}); + elseif nargin == 4 && isa(varargin{1},'gtdynamics.Robot') && isa(varargin{2},'char') && isa(varargin{3},'char') && isa(varargin{4},'gtsam.Values') + my_ptr = class_wrapper(59, varargin{1}, varargin{2}, varargin{3}, varargin{4}); + else + error('Arguments do not match any overload of ForwardKinematics constructor'); + end + obj.ptr_ForwardKinematics = my_ptr; + end + + function delete(obj) + class_wrapper(60, obj.ptr_ForwardKinematics); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/ForwardKinematicsFactor.m b/wrap/tests/expected/matlab/ForwardKinematicsFactor.m new file mode 100644 index 0000000000..46aa413928 --- /dev/null +++ b/wrap/tests/expected/matlab/ForwardKinematicsFactor.m @@ -0,0 +1,36 @@ +%class ForwardKinematicsFactor, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +classdef ForwardKinematicsFactor < gtsam.BetweenFactor + properties + ptr_ForwardKinematicsFactor = 0 + end + methods + function obj = ForwardKinematicsFactor(varargin) + if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void'))) && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + if nargin == 2 + my_ptr = varargin{2}; + else + my_ptr = inheritance_wrapper(36, varargin{2}); + end + base_ptr = inheritance_wrapper(35, my_ptr); + else + error('Arguments do not match any overload of ForwardKinematicsFactor constructor'); + end + obj = obj@gtsam.BetweenFactorPose3(uint64(5139824614673773682), base_ptr); + obj.ptr_ForwardKinematicsFactor = my_ptr; + end + + function delete(obj) + inheritance_wrapper(37, obj.ptr_ForwardKinematicsFactor); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/FunDouble.m b/wrap/tests/expected/matlab/FunDouble.m index 78609c7f64..5f432341b1 100644 --- a/wrap/tests/expected/matlab/FunDouble.m +++ b/wrap/tests/expected/matlab/FunDouble.m @@ -3,6 +3,7 @@ % %-------Methods------- %multiTemplatedMethodStringSize_t(double d, string t, size_t u) : returns Fun +%sets() : returns std::map::double> %templatedMethodString(double d, string t) : returns Fun % %-------Static Methods------- @@ -46,11 +47,21 @@ function delete(obj) error('Arguments do not match any overload of function FunDouble.multiTemplatedMethodStringSize_t'); end + function varargout = sets(this, varargin) + % SETS usage: sets() : returns std.mapdoubledouble + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 0 + varargout{1} = class_wrapper(8, this, varargin{:}); + return + end + error('Arguments do not match any overload of function FunDouble.sets'); + end + function varargout = templatedMethodString(this, varargin) % TEMPLATEDMETHODSTRING usage: templatedMethodString(double d, string t) : returns Fun % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'double') && isa(varargin{2},'char') - varargout{1} = class_wrapper(8, this, varargin{:}); + varargout{1} = class_wrapper(9, this, varargin{:}); return end error('Arguments do not match any overload of function FunDouble.templatedMethodString'); @@ -59,22 +70,22 @@ function delete(obj) end methods(Static = true) - function varargout = StaticMethodWithThis(varargin) + function varargout = staticMethodWithThis(varargin) % STATICMETHODWITHTHIS usage: staticMethodWithThis() : returns Fundouble % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - varargout{1} = class_wrapper(9, varargin{:}); + varargout{1} = class_wrapper(10, varargin{:}); return end error('Arguments do not match any overload of function FunDouble.staticMethodWithThis'); end - function varargout = TemplatedStaticMethodInt(varargin) + function varargout = templatedStaticMethodInt(varargin) % TEMPLATEDSTATICMETHODINT usage: templatedStaticMethodInt(int m) : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'numeric') - varargout{1} = class_wrapper(10, varargin{:}); + varargout{1} = class_wrapper(11, varargin{:}); return end diff --git a/wrap/tests/expected/matlab/FunRange.m b/wrap/tests/expected/matlab/FunRange.m index 1d1a6f7b87..52ee78aa2a 100644 --- a/wrap/tests/expected/matlab/FunRange.m +++ b/wrap/tests/expected/matlab/FunRange.m @@ -52,7 +52,7 @@ function delete(obj) end methods(Static = true) - function varargout = Create(varargin) + function varargout = create(varargin) % CREATE usage: create() : returns FunRange % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 diff --git a/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m b/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m index 863d30ee81..ebf263bcb9 100644 --- a/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m +++ b/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m @@ -9,7 +9,7 @@ function obj = MultipleTemplatesIntDouble(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(50, my_ptr); + class_wrapper(53, my_ptr); else error('Arguments do not match any overload of MultipleTemplatesIntDouble constructor'); end @@ -17,7 +17,7 @@ end function delete(obj) - class_wrapper(51, obj.ptr_MultipleTemplatesIntDouble); + class_wrapper(54, obj.ptr_MultipleTemplatesIntDouble); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m b/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m index b7f1fdac51..02290f0323 100644 --- a/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m +++ b/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m @@ -9,7 +9,7 @@ function obj = MultipleTemplatesIntFloat(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(52, my_ptr); + class_wrapper(55, my_ptr); else error('Arguments do not match any overload of MultipleTemplatesIntFloat constructor'); end @@ -17,7 +17,7 @@ end function delete(obj) - class_wrapper(53, obj.ptr_MultipleTemplatesIntFloat); + class_wrapper(56, obj.ptr_MultipleTemplatesIntFloat); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/MyFactorPosePoint2.m b/wrap/tests/expected/matlab/MyFactorPosePoint2.m index 7634ae2cbd..7457fe749b 100644 --- a/wrap/tests/expected/matlab/MyFactorPosePoint2.m +++ b/wrap/tests/expected/matlab/MyFactorPosePoint2.m @@ -15,9 +15,9 @@ function obj = MyFactorPosePoint2(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(63, my_ptr); + class_wrapper(67, my_ptr); elseif nargin == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'gtsam.noiseModel.Base') - my_ptr = class_wrapper(64, varargin{1}, varargin{2}, varargin{3}, varargin{4}); + my_ptr = class_wrapper(68, varargin{1}, varargin{2}, varargin{3}, varargin{4}); else error('Arguments do not match any overload of MyFactorPosePoint2 constructor'); end @@ -25,7 +25,7 @@ end function delete(obj) - class_wrapper(65, obj.ptr_MyFactorPosePoint2); + class_wrapper(69, obj.ptr_MyFactorPosePoint2); end function display(obj), obj.print(''); end @@ -36,7 +36,19 @@ function delete(obj) % PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter') - class_wrapper(66, this, varargin{:}); + class_wrapper(70, this, varargin{:}); + return + end + % PRINT usage: print(string s) : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 1 && isa(varargin{1},'char') + class_wrapper(71, this, varargin{:}); + return + end + % PRINT usage: print() : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 0 + class_wrapper(72, this, varargin{:}); return end error('Arguments do not match any overload of function MyFactorPosePoint2.print'); diff --git a/wrap/tests/expected/matlab/MyVector12.m b/wrap/tests/expected/matlab/MyVector12.m index 291d0d71ba..53e554a100 100644 --- a/wrap/tests/expected/matlab/MyVector12.m +++ b/wrap/tests/expected/matlab/MyVector12.m @@ -12,9 +12,9 @@ function obj = MyVector12(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(47, my_ptr); + class_wrapper(50, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(48); + my_ptr = class_wrapper(51); else error('Arguments do not match any overload of MyVector12 constructor'); end @@ -22,7 +22,7 @@ end function delete(obj) - class_wrapper(49, obj.ptr_MyVector12); + class_wrapper(52, obj.ptr_MyVector12); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/MyVector3.m b/wrap/tests/expected/matlab/MyVector3.m index 3051dc2e23..0f6ea84aba 100644 --- a/wrap/tests/expected/matlab/MyVector3.m +++ b/wrap/tests/expected/matlab/MyVector3.m @@ -12,9 +12,9 @@ function obj = MyVector3(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(44, my_ptr); + class_wrapper(47, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(45); + my_ptr = class_wrapper(48); else error('Arguments do not match any overload of MyVector3 constructor'); end @@ -22,7 +22,7 @@ end function delete(obj) - class_wrapper(46, obj.ptr_MyVector3); + class_wrapper(49, obj.ptr_MyVector3); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/PrimitiveRefDouble.m b/wrap/tests/expected/matlab/PrimitiveRefDouble.m index dd0a6d2daf..e1039e567d 100644 --- a/wrap/tests/expected/matlab/PrimitiveRefDouble.m +++ b/wrap/tests/expected/matlab/PrimitiveRefDouble.m @@ -19,9 +19,9 @@ function obj = PrimitiveRefDouble(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(40, my_ptr); + class_wrapper(43, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(41); + my_ptr = class_wrapper(44); else error('Arguments do not match any overload of PrimitiveRefDouble constructor'); end @@ -29,7 +29,7 @@ end function delete(obj) - class_wrapper(42, obj.ptr_PrimitiveRefDouble); + class_wrapper(45, obj.ptr_PrimitiveRefDouble); end function display(obj), obj.print(''); end @@ -43,7 +43,7 @@ function delete(obj) % BRUTAL usage: Brutal(double t) : returns PrimitiveRefdouble % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(43, varargin{:}); + varargout{1} = class_wrapper(46, varargin{:}); return end diff --git a/wrap/tests/expected/matlab/ScopedTemplateResult.m b/wrap/tests/expected/matlab/ScopedTemplateResult.m new file mode 100644 index 0000000000..8cb8ed7d04 --- /dev/null +++ b/wrap/tests/expected/matlab/ScopedTemplateResult.m @@ -0,0 +1,36 @@ +%class ScopedTemplateResult, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%ScopedTemplateResult(Result::Value arg) +% +classdef ScopedTemplateResult < handle + properties + ptr_ScopedTemplateResult = 0 + end + methods + function obj = ScopedTemplateResult(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + template_wrapper(6, my_ptr); + elseif nargin == 1 && isa(varargin{1},'Result::Value') + my_ptr = template_wrapper(7, varargin{1}); + else + error('Arguments do not match any overload of ScopedTemplateResult constructor'); + end + obj.ptr_ScopedTemplateResult = my_ptr; + end + + function delete(obj) + template_wrapper(8, obj.ptr_ScopedTemplateResult); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/TemplatedConstructor.m b/wrap/tests/expected/matlab/TemplatedConstructor.m new file mode 100644 index 0000000000..70beb26ce0 --- /dev/null +++ b/wrap/tests/expected/matlab/TemplatedConstructor.m @@ -0,0 +1,45 @@ +%class TemplatedConstructor, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%TemplatedConstructor() +%TemplatedConstructor(string arg) +%TemplatedConstructor(int arg) +%TemplatedConstructor(double arg) +% +classdef TemplatedConstructor < handle + properties + ptr_TemplatedConstructor = 0 + end + methods + function obj = TemplatedConstructor(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + template_wrapper(0, my_ptr); + elseif nargin == 0 + my_ptr = template_wrapper(1); + elseif nargin == 1 && isa(varargin{1},'char') + my_ptr = template_wrapper(2, varargin{1}); + elseif nargin == 1 && isa(varargin{1},'numeric') + my_ptr = template_wrapper(3, varargin{1}); + elseif nargin == 1 && isa(varargin{1},'double') + my_ptr = template_wrapper(4, varargin{1}); + else + error('Arguments do not match any overload of TemplatedConstructor constructor'); + end + obj.ptr_TemplatedConstructor = my_ptr; + end + + function delete(obj) + template_wrapper(5, obj.ptr_TemplatedConstructor); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/TemplatedFunctionRot3.m b/wrap/tests/expected/matlab/TemplatedFunctionRot3.m index 4216201b49..eb5cb4abea 100644 --- a/wrap/tests/expected/matlab/TemplatedFunctionRot3.m +++ b/wrap/tests/expected/matlab/TemplatedFunctionRot3.m @@ -1,6 +1,6 @@ function varargout = TemplatedFunctionRot3(varargin) if length(varargin) == 1 && isa(varargin{1},'gtsam.Rot3') - functions_wrapper(14, varargin{:}); + functions_wrapper(25, varargin{:}); else error('Arguments do not match any overload of function TemplatedFunctionRot3'); end diff --git a/wrap/tests/expected/matlab/Test.m b/wrap/tests/expected/matlab/Test.m index 8569120c5b..66ba4721c0 100644 --- a/wrap/tests/expected/matlab/Test.m +++ b/wrap/tests/expected/matlab/Test.m @@ -11,6 +11,7 @@ %create_ptrs() : returns pair< Test, Test > %get_container() : returns std::vector %lambda() : returns void +%markdown(KeyFormatter keyFormatter) : returns string %print() : returns void %return_Point2Ptr(bool value) : returns Point2 %return_Test(Test value) : returns Test @@ -40,11 +41,11 @@ function obj = Test(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(11, my_ptr); + class_wrapper(12, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(12); + my_ptr = class_wrapper(13); elseif nargin == 2 && isa(varargin{1},'double') && isa(varargin{2},'double') - my_ptr = class_wrapper(13, varargin{1}, varargin{2}); + my_ptr = class_wrapper(14, varargin{1}, varargin{2}); else error('Arguments do not match any overload of Test constructor'); end @@ -52,7 +53,7 @@ end function delete(obj) - class_wrapper(14, obj.ptr_Test); + class_wrapper(15, obj.ptr_Test); end function display(obj), obj.print(''); end @@ -63,7 +64,7 @@ function delete(obj) % ARG_EIGENCONSTREF usage: arg_EigenConstRef(Matrix value) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - class_wrapper(15, this, varargin{:}); + class_wrapper(16, this, varargin{:}); return end error('Arguments do not match any overload of function Test.arg_EigenConstRef'); @@ -73,7 +74,7 @@ function delete(obj) % CREATE_MIXEDPTRS usage: create_MixedPtrs() : returns pair< Test, Test > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - [ varargout{1} varargout{2} ] = class_wrapper(16, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(17, this, varargin{:}); return end error('Arguments do not match any overload of function Test.create_MixedPtrs'); @@ -83,7 +84,7 @@ function delete(obj) % CREATE_PTRS usage: create_ptrs() : returns pair< Test, Test > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - [ varargout{1} varargout{2} ] = class_wrapper(17, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(18, this, varargin{:}); return end error('Arguments do not match any overload of function Test.create_ptrs'); @@ -93,7 +94,7 @@ function delete(obj) % GET_CONTAINER usage: get_container() : returns std.vectorTest % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - varargout{1} = class_wrapper(18, this, varargin{:}); + varargout{1} = class_wrapper(19, this, varargin{:}); return end error('Arguments do not match any overload of function Test.get_container'); @@ -103,17 +104,33 @@ function delete(obj) % LAMBDA usage: lambda() : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - class_wrapper(19, this, varargin{:}); + class_wrapper(20, this, varargin{:}); return end error('Arguments do not match any overload of function Test.lambda'); end + function varargout = markdown(this, varargin) + % MARKDOWN usage: markdown(KeyFormatter keyFormatter) : returns string + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 1 && isa(varargin{1},'gtsam.KeyFormatter') + varargout{1} = class_wrapper(21, this, varargin{:}); + return + end + % MARKDOWN usage: markdown() : returns string + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 0 + varargout{1} = class_wrapper(22, this, varargin{:}); + return + end + error('Arguments do not match any overload of function Test.markdown'); + end + function varargout = print(this, varargin) % PRINT usage: print() : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - class_wrapper(20, this, varargin{:}); + class_wrapper(23, this, varargin{:}); return end error('Arguments do not match any overload of function Test.print'); @@ -123,7 +140,7 @@ function delete(obj) % RETURN_POINT2PTR usage: return_Point2Ptr(bool value) : returns Point2 % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'logical') - varargout{1} = class_wrapper(21, this, varargin{:}); + varargout{1} = class_wrapper(24, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_Point2Ptr'); @@ -133,7 +150,7 @@ function delete(obj) % RETURN_TEST usage: return_Test(Test value) : returns Test % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'Test') - varargout{1} = class_wrapper(22, this, varargin{:}); + varargout{1} = class_wrapper(25, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_Test'); @@ -143,7 +160,7 @@ function delete(obj) % RETURN_TESTPTR usage: return_TestPtr(Test value) : returns Test % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'Test') - varargout{1} = class_wrapper(23, this, varargin{:}); + varargout{1} = class_wrapper(26, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_TestPtr'); @@ -153,7 +170,7 @@ function delete(obj) % RETURN_BOOL usage: return_bool(bool value) : returns bool % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'logical') - varargout{1} = class_wrapper(24, this, varargin{:}); + varargout{1} = class_wrapper(27, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_bool'); @@ -163,7 +180,7 @@ function delete(obj) % RETURN_DOUBLE usage: return_double(double value) : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(25, this, varargin{:}); + varargout{1} = class_wrapper(28, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_double'); @@ -173,7 +190,7 @@ function delete(obj) % RETURN_FIELD usage: return_field(Test t) : returns bool % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'Test') - varargout{1} = class_wrapper(26, this, varargin{:}); + varargout{1} = class_wrapper(29, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_field'); @@ -183,7 +200,7 @@ function delete(obj) % RETURN_INT usage: return_int(int value) : returns int % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'numeric') - varargout{1} = class_wrapper(27, this, varargin{:}); + varargout{1} = class_wrapper(30, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_int'); @@ -193,7 +210,7 @@ function delete(obj) % RETURN_MATRIX1 usage: return_matrix1(Matrix value) : returns Matrix % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(28, this, varargin{:}); + varargout{1} = class_wrapper(31, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_matrix1'); @@ -203,7 +220,7 @@ function delete(obj) % RETURN_MATRIX2 usage: return_matrix2(Matrix value) : returns Matrix % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(29, this, varargin{:}); + varargout{1} = class_wrapper(32, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_matrix2'); @@ -213,13 +230,13 @@ function delete(obj) % RETURN_PAIR usage: return_pair(Vector v, Matrix A) : returns pair< Vector, Matrix > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'double') && size(varargin{1},2)==1 && isa(varargin{2},'double') - [ varargout{1} varargout{2} ] = class_wrapper(30, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(33, this, varargin{:}); return end % RETURN_PAIR usage: return_pair(Vector v) : returns pair< Vector, Matrix > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') && size(varargin{1},2)==1 - [ varargout{1} varargout{2} ] = class_wrapper(31, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(34, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_pair'); @@ -229,7 +246,7 @@ function delete(obj) % RETURN_PTRS usage: return_ptrs(Test p1, Test p2) : returns pair< Test, Test > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'Test') && isa(varargin{2},'Test') - [ varargout{1} varargout{2} ] = class_wrapper(32, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(35, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_ptrs'); @@ -239,7 +256,7 @@ function delete(obj) % RETURN_SIZE_T usage: return_size_t(size_t value) : returns size_t % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'numeric') - varargout{1} = class_wrapper(33, this, varargin{:}); + varargout{1} = class_wrapper(36, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_size_t'); @@ -249,7 +266,7 @@ function delete(obj) % RETURN_STRING usage: return_string(string value) : returns string % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'char') - varargout{1} = class_wrapper(34, this, varargin{:}); + varargout{1} = class_wrapper(37, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_string'); @@ -259,7 +276,7 @@ function delete(obj) % RETURN_VECTOR1 usage: return_vector1(Vector value) : returns Vector % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') && size(varargin{1},2)==1 - varargout{1} = class_wrapper(35, this, varargin{:}); + varargout{1} = class_wrapper(38, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_vector1'); @@ -269,7 +286,7 @@ function delete(obj) % RETURN_VECTOR2 usage: return_vector2(Vector value) : returns Vector % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') && size(varargin{1},2)==1 - varargout{1} = class_wrapper(36, this, varargin{:}); + varargout{1} = class_wrapper(39, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_vector2'); @@ -279,19 +296,19 @@ function delete(obj) % SET_CONTAINER usage: set_container(vector container) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'std.vectorTest') - class_wrapper(37, this, varargin{:}); + class_wrapper(40, this, varargin{:}); return end % SET_CONTAINER usage: set_container(vector container) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'std.vectorTest') - class_wrapper(38, this, varargin{:}); + class_wrapper(41, this, varargin{:}); return end % SET_CONTAINER usage: set_container(vector container) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'std.vectorTest') - class_wrapper(39, this, varargin{:}); + class_wrapper(42, this, varargin{:}); return end error('Arguments do not match any overload of function Test.set_container'); diff --git a/wrap/tests/expected/matlab/class_wrapper.cpp b/wrap/tests/expected/matlab/class_wrapper.cpp index df6cb33071..03a25c358f 100644 --- a/wrap/tests/expected/matlab/class_wrapper.cpp +++ b/wrap/tests/expected/matlab/class_wrapper.cpp @@ -145,7 +145,7 @@ void _class_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_class_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -180,9 +180,9 @@ void FunRange_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxA Collector_FunRange::iterator item; item = collector_FunRange.find(self); if(item != collector_FunRange.end()) { - delete self; collector_FunRange.erase(item); } + delete self; } void FunRange_range_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -216,9 +216,9 @@ void FunDouble_deconstructor_6(int nargout, mxArray *out[], int nargin, const mx Collector_FunDouble::iterator item; item = collector_FunDouble.find(self); if(item != collector_FunDouble.end()) { - delete self; collector_FunDouble.erase(item); } + delete self; } void FunDouble_multiTemplatedMethod_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -231,7 +231,14 @@ void FunDouble_multiTemplatedMethod_7(int nargout, mxArray *out[], int nargin, c out[0] = wrap_shared_ptr(boost::make_shared>(obj->multiTemplatedMethod(d,t,u)),"Fun", false); } -void FunDouble_templatedMethod_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void FunDouble_sets_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("sets",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_FunDouble"); + out[0] = wrap_shared_ptr(boost::make_shared::double>>(obj->sets()),"std.mapdoubledouble", false); +} + +void FunDouble_templatedMethod_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("templatedMethodString",nargout,nargin-1,2); auto obj = unwrap_shared_ptr>(in[0], "ptr_FunDouble"); @@ -240,20 +247,20 @@ void FunDouble_templatedMethod_8(int nargout, mxArray *out[], int nargin, const out[0] = wrap_shared_ptr(boost::make_shared>(obj->templatedMethod(d,t)),"Fun", false); } -void FunDouble_staticMethodWithThis_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void FunDouble_staticMethodWithThis_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("FunDouble.staticMethodWithThis",nargout,nargin,0); + checkArguments("Fun.staticMethodWithThis",nargout,nargin,0); out[0] = wrap_shared_ptr(boost::make_shared>(Fun::staticMethodWithThis()),"Fundouble", false); } -void FunDouble_templatedStaticMethodInt_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void FunDouble_templatedStaticMethodInt_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("FunDouble.templatedStaticMethodInt",nargout,nargin,1); + checkArguments("Fun.templatedStaticMethodInt",nargout,nargin,1); int m = unwrap< int >(in[0]); - out[0] = wrap< double >(Fun::templatedStaticMethodInt(m)); + out[0] = wrap< double >(Fun::templatedStaticMethod(m)); } -void Test_collectorInsertAndMakeBase_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_collectorInsertAndMakeBase_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -262,7 +269,7 @@ void Test_collectorInsertAndMakeBase_11(int nargout, mxArray *out[], int nargin, collector_Test.insert(self); } -void Test_constructor_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -273,7 +280,7 @@ void Test_constructor_12(int nargout, mxArray *out[], int nargin, const mxArray *reinterpret_cast (mxGetData(out[0])) = self; } -void Test_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_constructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -286,7 +293,7 @@ void Test_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *reinterpret_cast (mxGetData(out[0])) = self; } -void Test_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_deconstructor_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr Shared; checkArguments("delete_Test",nargout,nargin,1); @@ -294,12 +301,12 @@ void Test_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArra Collector_Test::iterator item; item = collector_Test.find(self); if(item != collector_Test.end()) { - delete self; collector_Test.erase(item); } + delete self; } -void Test_arg_EigenConstRef_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_arg_EigenConstRef_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("arg_EigenConstRef",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -307,7 +314,7 @@ void Test_arg_EigenConstRef_15(int nargout, mxArray *out[], int nargin, const mx obj->arg_EigenConstRef(value); } -void Test_create_MixedPtrs_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_create_MixedPtrs_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("create_MixedPtrs",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -316,7 +323,7 @@ void Test_create_MixedPtrs_16(int nargout, mxArray *out[], int nargin, const mxA out[1] = wrap_shared_ptr(pairResult.second,"Test", false); } -void Test_create_ptrs_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_create_ptrs_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("create_ptrs",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -325,28 +332,43 @@ void Test_create_ptrs_17(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap_shared_ptr(pairResult.second,"Test", false); } -void Test_get_container_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_get_container_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("get_container",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); out[0] = wrap_shared_ptr(boost::make_shared>(obj->get_container()),"std.vectorTest", false); } -void Test_lambda_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_lambda_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("lambda",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); obj->lambda(); } -void Test_print_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_markdown_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("markdown",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); + gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[1], "ptr_gtsamKeyFormatter"); + out[0] = wrap< string >(obj->markdown(keyFormatter)); +} + +void Test_markdown_22(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("markdown",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); + out[0] = wrap< string >(obj->markdown(gtsam::DefaultKeyFormatter)); +} + +void Test_print_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("print",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); obj->print(); } -void Test_return_Point2Ptr_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_Point2Ptr_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_Point2Ptr",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -357,7 +379,7 @@ void Test_return_Point2Ptr_21(int nargout, mxArray *out[], int nargin, const mxA } } -void Test_return_Test_22(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_Test_25(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_Test",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -365,7 +387,7 @@ void Test_return_Test_22(int nargout, mxArray *out[], int nargin, const mxArray out[0] = wrap_shared_ptr(boost::make_shared(obj->return_Test(value)),"Test", false); } -void Test_return_TestPtr_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_TestPtr_26(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_TestPtr",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -373,7 +395,7 @@ void Test_return_TestPtr_23(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap_shared_ptr(obj->return_TestPtr(value),"Test", false); } -void Test_return_bool_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_bool_27(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_bool",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -381,7 +403,7 @@ void Test_return_bool_24(int nargout, mxArray *out[], int nargin, const mxArray out[0] = wrap< bool >(obj->return_bool(value)); } -void Test_return_double_25(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_double_28(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_double",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -389,7 +411,7 @@ void Test_return_double_25(int nargout, mxArray *out[], int nargin, const mxArra out[0] = wrap< double >(obj->return_double(value)); } -void Test_return_field_26(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_field_29(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_field",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -397,7 +419,7 @@ void Test_return_field_26(int nargout, mxArray *out[], int nargin, const mxArray out[0] = wrap< bool >(obj->return_field(t)); } -void Test_return_int_27(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_int_30(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_int",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -405,7 +427,7 @@ void Test_return_int_27(int nargout, mxArray *out[], int nargin, const mxArray * out[0] = wrap< int >(obj->return_int(value)); } -void Test_return_matrix1_28(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_matrix1_31(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_matrix1",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -413,7 +435,7 @@ void Test_return_matrix1_28(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Matrix >(obj->return_matrix1(value)); } -void Test_return_matrix2_29(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_matrix2_32(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_matrix2",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -421,7 +443,7 @@ void Test_return_matrix2_29(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Matrix >(obj->return_matrix2(value)); } -void Test_return_pair_30(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_pair_33(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_pair",nargout,nargin-1,2); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -432,7 +454,7 @@ void Test_return_pair_30(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap< Matrix >(pairResult.second); } -void Test_return_pair_31(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_pair_34(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_pair",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -442,7 +464,7 @@ void Test_return_pair_31(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap< Matrix >(pairResult.second); } -void Test_return_ptrs_32(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_ptrs_35(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_ptrs",nargout,nargin-1,2); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -453,7 +475,7 @@ void Test_return_ptrs_32(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap_shared_ptr(pairResult.second,"Test", false); } -void Test_return_size_t_33(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_size_t_36(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_size_t",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -461,7 +483,7 @@ void Test_return_size_t_33(int nargout, mxArray *out[], int nargin, const mxArra out[0] = wrap< size_t >(obj->return_size_t(value)); } -void Test_return_string_34(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_string_37(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_string",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -469,7 +491,7 @@ void Test_return_string_34(int nargout, mxArray *out[], int nargin, const mxArra out[0] = wrap< string >(obj->return_string(value)); } -void Test_return_vector1_35(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_vector1_38(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_vector1",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -477,7 +499,7 @@ void Test_return_vector1_35(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Vector >(obj->return_vector1(value)); } -void Test_return_vector2_36(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_vector2_39(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_vector2",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -485,7 +507,7 @@ void Test_return_vector2_36(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Vector >(obj->return_vector2(value)); } -void Test_set_container_37(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_set_container_40(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("set_container",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -493,7 +515,7 @@ void Test_set_container_37(int nargout, mxArray *out[], int nargin, const mxArra obj->set_container(*container); } -void Test_set_container_38(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_set_container_41(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("set_container",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -501,7 +523,7 @@ void Test_set_container_38(int nargout, mxArray *out[], int nargin, const mxArra obj->set_container(*container); } -void Test_set_container_39(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_set_container_42(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("set_container",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -509,7 +531,7 @@ void Test_set_container_39(int nargout, mxArray *out[], int nargin, const mxArra obj->set_container(*container); } -void PrimitiveRefDouble_collectorInsertAndMakeBase_40(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_collectorInsertAndMakeBase_43(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -518,7 +540,7 @@ void PrimitiveRefDouble_collectorInsertAndMakeBase_40(int nargout, mxArray *out[ collector_PrimitiveRefDouble.insert(self); } -void PrimitiveRefDouble_constructor_41(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_constructor_44(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -529,7 +551,7 @@ void PrimitiveRefDouble_constructor_41(int nargout, mxArray *out[], int nargin, *reinterpret_cast (mxGetData(out[0])) = self; } -void PrimitiveRefDouble_deconstructor_42(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_deconstructor_45(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_PrimitiveRefDouble",nargout,nargin,1); @@ -537,19 +559,19 @@ void PrimitiveRefDouble_deconstructor_42(int nargout, mxArray *out[], int nargin Collector_PrimitiveRefDouble::iterator item; item = collector_PrimitiveRefDouble.find(self); if(item != collector_PrimitiveRefDouble.end()) { - delete self; collector_PrimitiveRefDouble.erase(item); } + delete self; } -void PrimitiveRefDouble_Brutal_43(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_Brutal_46(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("PrimitiveRefDouble.Brutal",nargout,nargin,1); + checkArguments("PrimitiveRef.Brutal",nargout,nargin,1); double t = unwrap< double >(in[0]); out[0] = wrap_shared_ptr(boost::make_shared>(PrimitiveRef::Brutal(t)),"PrimitiveRefdouble", false); } -void MyVector3_collectorInsertAndMakeBase_44(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector3_collectorInsertAndMakeBase_47(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -558,7 +580,7 @@ void MyVector3_collectorInsertAndMakeBase_44(int nargout, mxArray *out[], int na collector_MyVector3.insert(self); } -void MyVector3_constructor_45(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector3_constructor_48(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -569,7 +591,7 @@ void MyVector3_constructor_45(int nargout, mxArray *out[], int nargin, const mxA *reinterpret_cast (mxGetData(out[0])) = self; } -void MyVector3_deconstructor_46(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector3_deconstructor_49(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MyVector3",nargout,nargin,1); @@ -577,12 +599,12 @@ void MyVector3_deconstructor_46(int nargout, mxArray *out[], int nargin, const m Collector_MyVector3::iterator item; item = collector_MyVector3.find(self); if(item != collector_MyVector3.end()) { - delete self; collector_MyVector3.erase(item); } + delete self; } -void MyVector12_collectorInsertAndMakeBase_47(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector12_collectorInsertAndMakeBase_50(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -591,7 +613,7 @@ void MyVector12_collectorInsertAndMakeBase_47(int nargout, mxArray *out[], int n collector_MyVector12.insert(self); } -void MyVector12_constructor_48(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector12_constructor_51(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -602,7 +624,7 @@ void MyVector12_constructor_48(int nargout, mxArray *out[], int nargin, const mx *reinterpret_cast (mxGetData(out[0])) = self; } -void MyVector12_deconstructor_49(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector12_deconstructor_52(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MyVector12",nargout,nargin,1); @@ -610,12 +632,12 @@ void MyVector12_deconstructor_49(int nargout, mxArray *out[], int nargin, const Collector_MyVector12::iterator item; item = collector_MyVector12.find(self); if(item != collector_MyVector12.end()) { - delete self; collector_MyVector12.erase(item); } + delete self; } -void MultipleTemplatesIntDouble_collectorInsertAndMakeBase_50(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntDouble_collectorInsertAndMakeBase_53(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -624,7 +646,7 @@ void MultipleTemplatesIntDouble_collectorInsertAndMakeBase_50(int nargout, mxArr collector_MultipleTemplatesIntDouble.insert(self); } -void MultipleTemplatesIntDouble_deconstructor_51(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntDouble_deconstructor_54(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MultipleTemplatesIntDouble",nargout,nargin,1); @@ -632,12 +654,12 @@ void MultipleTemplatesIntDouble_deconstructor_51(int nargout, mxArray *out[], in Collector_MultipleTemplatesIntDouble::iterator item; item = collector_MultipleTemplatesIntDouble.find(self); if(item != collector_MultipleTemplatesIntDouble.end()) { - delete self; collector_MultipleTemplatesIntDouble.erase(item); } + delete self; } -void MultipleTemplatesIntFloat_collectorInsertAndMakeBase_52(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntFloat_collectorInsertAndMakeBase_55(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -646,7 +668,7 @@ void MultipleTemplatesIntFloat_collectorInsertAndMakeBase_52(int nargout, mxArra collector_MultipleTemplatesIntFloat.insert(self); } -void MultipleTemplatesIntFloat_deconstructor_53(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntFloat_deconstructor_56(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MultipleTemplatesIntFloat",nargout,nargin,1); @@ -654,12 +676,12 @@ void MultipleTemplatesIntFloat_deconstructor_53(int nargout, mxArray *out[], int Collector_MultipleTemplatesIntFloat::iterator item; item = collector_MultipleTemplatesIntFloat.find(self); if(item != collector_MultipleTemplatesIntFloat.end()) { - delete self; collector_MultipleTemplatesIntFloat.erase(item); } + delete self; } -void ForwardKinematics_collectorInsertAndMakeBase_54(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void ForwardKinematics_collectorInsertAndMakeBase_57(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -668,7 +690,7 @@ void ForwardKinematics_collectorInsertAndMakeBase_54(int nargout, mxArray *out[] collector_ForwardKinematics.insert(self); } -void ForwardKinematics_constructor_55(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void ForwardKinematics_constructor_58(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -684,7 +706,22 @@ void ForwardKinematics_constructor_55(int nargout, mxArray *out[], int nargin, c *reinterpret_cast (mxGetData(out[0])) = self; } -void ForwardKinematics_deconstructor_56(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void ForwardKinematics_constructor_59(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + gtdynamics::Robot& robot = *unwrap_shared_ptr< gtdynamics::Robot >(in[0], "ptr_gtdynamicsRobot"); + string& start_link_name = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + string& end_link_name = *unwrap_shared_ptr< string >(in[2], "ptr_string"); + gtsam::Values& joint_angles = *unwrap_shared_ptr< gtsam::Values >(in[3], "ptr_gtsamValues"); + Shared *self = new Shared(new ForwardKinematics(robot,start_link_name,end_link_name,joint_angles,gtsam::Pose3())); + collector_ForwardKinematics.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void ForwardKinematics_deconstructor_60(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr Shared; checkArguments("delete_ForwardKinematics",nargout,nargin,1); @@ -692,12 +729,12 @@ void ForwardKinematics_deconstructor_56(int nargout, mxArray *out[], int nargin, Collector_ForwardKinematics::iterator item; item = collector_ForwardKinematics.find(self); if(item != collector_ForwardKinematics.end()) { - delete self; collector_ForwardKinematics.erase(item); } + delete self; } -void TemplatedConstructor_collectorInsertAndMakeBase_57(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_collectorInsertAndMakeBase_61(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -706,7 +743,7 @@ void TemplatedConstructor_collectorInsertAndMakeBase_57(int nargout, mxArray *ou collector_TemplatedConstructor.insert(self); } -void TemplatedConstructor_constructor_58(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_62(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -717,7 +754,7 @@ void TemplatedConstructor_constructor_58(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_constructor_59(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_63(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -729,7 +766,7 @@ void TemplatedConstructor_constructor_59(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_constructor_60(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_64(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -741,7 +778,7 @@ void TemplatedConstructor_constructor_60(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_constructor_61(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_65(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -753,7 +790,7 @@ void TemplatedConstructor_constructor_61(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_deconstructor_62(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_deconstructor_66(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr Shared; checkArguments("delete_TemplatedConstructor",nargout,nargin,1); @@ -761,12 +798,12 @@ void TemplatedConstructor_deconstructor_62(int nargout, mxArray *out[], int narg Collector_TemplatedConstructor::iterator item; item = collector_TemplatedConstructor.find(self); if(item != collector_TemplatedConstructor.end()) { - delete self; collector_TemplatedConstructor.erase(item); } + delete self; } -void MyFactorPosePoint2_collectorInsertAndMakeBase_63(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_collectorInsertAndMakeBase_67(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -775,7 +812,7 @@ void MyFactorPosePoint2_collectorInsertAndMakeBase_63(int nargout, mxArray *out[ collector_MyFactorPosePoint2.insert(self); } -void MyFactorPosePoint2_constructor_64(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_constructor_68(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -790,7 +827,7 @@ void MyFactorPosePoint2_constructor_64(int nargout, mxArray *out[], int nargin, *reinterpret_cast (mxGetData(out[0])) = self; } -void MyFactorPosePoint2_deconstructor_65(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_deconstructor_69(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MyFactorPosePoint2",nargout,nargin,1); @@ -798,12 +835,12 @@ void MyFactorPosePoint2_deconstructor_65(int nargout, mxArray *out[], int nargin Collector_MyFactorPosePoint2::iterator item; item = collector_MyFactorPosePoint2.find(self); if(item != collector_MyFactorPosePoint2.end()) { - delete self; collector_MyFactorPosePoint2.erase(item); } + delete self; } -void MyFactorPosePoint2_print_66(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_print_70(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("print",nargout,nargin-1,2); auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); @@ -812,6 +849,21 @@ void MyFactorPosePoint2_print_66(int nargout, mxArray *out[], int nargin, const obj->print(s,keyFormatter); } +void MyFactorPosePoint2_print_71(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("print",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); + string& s = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + obj->print(s,gtsam::DefaultKeyFormatter); +} + +void MyFactorPosePoint2_print_72(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("print",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); + obj->print("factor: ",gtsam::DefaultKeyFormatter); +} + void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { @@ -849,181 +901,199 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) FunDouble_multiTemplatedMethod_7(nargout, out, nargin-1, in+1); break; case 8: - FunDouble_templatedMethod_8(nargout, out, nargin-1, in+1); + FunDouble_sets_8(nargout, out, nargin-1, in+1); break; case 9: - FunDouble_staticMethodWithThis_9(nargout, out, nargin-1, in+1); + FunDouble_templatedMethod_9(nargout, out, nargin-1, in+1); break; case 10: - FunDouble_templatedStaticMethodInt_10(nargout, out, nargin-1, in+1); + FunDouble_staticMethodWithThis_10(nargout, out, nargin-1, in+1); break; case 11: - Test_collectorInsertAndMakeBase_11(nargout, out, nargin-1, in+1); + FunDouble_templatedStaticMethodInt_11(nargout, out, nargin-1, in+1); break; case 12: - Test_constructor_12(nargout, out, nargin-1, in+1); + Test_collectorInsertAndMakeBase_12(nargout, out, nargin-1, in+1); break; case 13: Test_constructor_13(nargout, out, nargin-1, in+1); break; case 14: - Test_deconstructor_14(nargout, out, nargin-1, in+1); + Test_constructor_14(nargout, out, nargin-1, in+1); break; case 15: - Test_arg_EigenConstRef_15(nargout, out, nargin-1, in+1); + Test_deconstructor_15(nargout, out, nargin-1, in+1); break; case 16: - Test_create_MixedPtrs_16(nargout, out, nargin-1, in+1); + Test_arg_EigenConstRef_16(nargout, out, nargin-1, in+1); break; case 17: - Test_create_ptrs_17(nargout, out, nargin-1, in+1); + Test_create_MixedPtrs_17(nargout, out, nargin-1, in+1); break; case 18: - Test_get_container_18(nargout, out, nargin-1, in+1); + Test_create_ptrs_18(nargout, out, nargin-1, in+1); break; case 19: - Test_lambda_19(nargout, out, nargin-1, in+1); + Test_get_container_19(nargout, out, nargin-1, in+1); break; case 20: - Test_print_20(nargout, out, nargin-1, in+1); + Test_lambda_20(nargout, out, nargin-1, in+1); break; case 21: - Test_return_Point2Ptr_21(nargout, out, nargin-1, in+1); + Test_markdown_21(nargout, out, nargin-1, in+1); break; case 22: - Test_return_Test_22(nargout, out, nargin-1, in+1); + Test_markdown_22(nargout, out, nargin-1, in+1); break; case 23: - Test_return_TestPtr_23(nargout, out, nargin-1, in+1); + Test_print_23(nargout, out, nargin-1, in+1); break; case 24: - Test_return_bool_24(nargout, out, nargin-1, in+1); + Test_return_Point2Ptr_24(nargout, out, nargin-1, in+1); break; case 25: - Test_return_double_25(nargout, out, nargin-1, in+1); + Test_return_Test_25(nargout, out, nargin-1, in+1); break; case 26: - Test_return_field_26(nargout, out, nargin-1, in+1); + Test_return_TestPtr_26(nargout, out, nargin-1, in+1); break; case 27: - Test_return_int_27(nargout, out, nargin-1, in+1); + Test_return_bool_27(nargout, out, nargin-1, in+1); break; case 28: - Test_return_matrix1_28(nargout, out, nargin-1, in+1); + Test_return_double_28(nargout, out, nargin-1, in+1); break; case 29: - Test_return_matrix2_29(nargout, out, nargin-1, in+1); + Test_return_field_29(nargout, out, nargin-1, in+1); break; case 30: - Test_return_pair_30(nargout, out, nargin-1, in+1); + Test_return_int_30(nargout, out, nargin-1, in+1); break; case 31: - Test_return_pair_31(nargout, out, nargin-1, in+1); + Test_return_matrix1_31(nargout, out, nargin-1, in+1); break; case 32: - Test_return_ptrs_32(nargout, out, nargin-1, in+1); + Test_return_matrix2_32(nargout, out, nargin-1, in+1); break; case 33: - Test_return_size_t_33(nargout, out, nargin-1, in+1); + Test_return_pair_33(nargout, out, nargin-1, in+1); break; case 34: - Test_return_string_34(nargout, out, nargin-1, in+1); + Test_return_pair_34(nargout, out, nargin-1, in+1); break; case 35: - Test_return_vector1_35(nargout, out, nargin-1, in+1); + Test_return_ptrs_35(nargout, out, nargin-1, in+1); break; case 36: - Test_return_vector2_36(nargout, out, nargin-1, in+1); + Test_return_size_t_36(nargout, out, nargin-1, in+1); break; case 37: - Test_set_container_37(nargout, out, nargin-1, in+1); + Test_return_string_37(nargout, out, nargin-1, in+1); break; case 38: - Test_set_container_38(nargout, out, nargin-1, in+1); + Test_return_vector1_38(nargout, out, nargin-1, in+1); break; case 39: - Test_set_container_39(nargout, out, nargin-1, in+1); + Test_return_vector2_39(nargout, out, nargin-1, in+1); break; case 40: - PrimitiveRefDouble_collectorInsertAndMakeBase_40(nargout, out, nargin-1, in+1); + Test_set_container_40(nargout, out, nargin-1, in+1); break; case 41: - PrimitiveRefDouble_constructor_41(nargout, out, nargin-1, in+1); + Test_set_container_41(nargout, out, nargin-1, in+1); break; case 42: - PrimitiveRefDouble_deconstructor_42(nargout, out, nargin-1, in+1); + Test_set_container_42(nargout, out, nargin-1, in+1); break; case 43: - PrimitiveRefDouble_Brutal_43(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_collectorInsertAndMakeBase_43(nargout, out, nargin-1, in+1); break; case 44: - MyVector3_collectorInsertAndMakeBase_44(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_constructor_44(nargout, out, nargin-1, in+1); break; case 45: - MyVector3_constructor_45(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_deconstructor_45(nargout, out, nargin-1, in+1); break; case 46: - MyVector3_deconstructor_46(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_Brutal_46(nargout, out, nargin-1, in+1); break; case 47: - MyVector12_collectorInsertAndMakeBase_47(nargout, out, nargin-1, in+1); + MyVector3_collectorInsertAndMakeBase_47(nargout, out, nargin-1, in+1); break; case 48: - MyVector12_constructor_48(nargout, out, nargin-1, in+1); + MyVector3_constructor_48(nargout, out, nargin-1, in+1); break; case 49: - MyVector12_deconstructor_49(nargout, out, nargin-1, in+1); + MyVector3_deconstructor_49(nargout, out, nargin-1, in+1); break; case 50: - MultipleTemplatesIntDouble_collectorInsertAndMakeBase_50(nargout, out, nargin-1, in+1); + MyVector12_collectorInsertAndMakeBase_50(nargout, out, nargin-1, in+1); break; case 51: - MultipleTemplatesIntDouble_deconstructor_51(nargout, out, nargin-1, in+1); + MyVector12_constructor_51(nargout, out, nargin-1, in+1); break; case 52: - MultipleTemplatesIntFloat_collectorInsertAndMakeBase_52(nargout, out, nargin-1, in+1); + MyVector12_deconstructor_52(nargout, out, nargin-1, in+1); break; case 53: - MultipleTemplatesIntFloat_deconstructor_53(nargout, out, nargin-1, in+1); + MultipleTemplatesIntDouble_collectorInsertAndMakeBase_53(nargout, out, nargin-1, in+1); break; case 54: - ForwardKinematics_collectorInsertAndMakeBase_54(nargout, out, nargin-1, in+1); + MultipleTemplatesIntDouble_deconstructor_54(nargout, out, nargin-1, in+1); break; case 55: - ForwardKinematics_constructor_55(nargout, out, nargin-1, in+1); + MultipleTemplatesIntFloat_collectorInsertAndMakeBase_55(nargout, out, nargin-1, in+1); break; case 56: - ForwardKinematics_deconstructor_56(nargout, out, nargin-1, in+1); + MultipleTemplatesIntFloat_deconstructor_56(nargout, out, nargin-1, in+1); break; case 57: - TemplatedConstructor_collectorInsertAndMakeBase_57(nargout, out, nargin-1, in+1); + ForwardKinematics_collectorInsertAndMakeBase_57(nargout, out, nargin-1, in+1); break; case 58: - TemplatedConstructor_constructor_58(nargout, out, nargin-1, in+1); + ForwardKinematics_constructor_58(nargout, out, nargin-1, in+1); break; case 59: - TemplatedConstructor_constructor_59(nargout, out, nargin-1, in+1); + ForwardKinematics_constructor_59(nargout, out, nargin-1, in+1); break; case 60: - TemplatedConstructor_constructor_60(nargout, out, nargin-1, in+1); + ForwardKinematics_deconstructor_60(nargout, out, nargin-1, in+1); break; case 61: - TemplatedConstructor_constructor_61(nargout, out, nargin-1, in+1); + TemplatedConstructor_collectorInsertAndMakeBase_61(nargout, out, nargin-1, in+1); break; case 62: - TemplatedConstructor_deconstructor_62(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_62(nargout, out, nargin-1, in+1); break; case 63: - MyFactorPosePoint2_collectorInsertAndMakeBase_63(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_63(nargout, out, nargin-1, in+1); break; case 64: - MyFactorPosePoint2_constructor_64(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_64(nargout, out, nargin-1, in+1); break; case 65: - MyFactorPosePoint2_deconstructor_65(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_65(nargout, out, nargin-1, in+1); break; case 66: - MyFactorPosePoint2_print_66(nargout, out, nargin-1, in+1); + TemplatedConstructor_deconstructor_66(nargout, out, nargin-1, in+1); + break; + case 67: + MyFactorPosePoint2_collectorInsertAndMakeBase_67(nargout, out, nargin-1, in+1); + break; + case 68: + MyFactorPosePoint2_constructor_68(nargout, out, nargin-1, in+1); + break; + case 69: + MyFactorPosePoint2_deconstructor_69(nargout, out, nargin-1, in+1); + break; + case 70: + MyFactorPosePoint2_print_70(nargout, out, nargin-1, in+1); + break; + case 71: + MyFactorPosePoint2_print_71(nargout, out, nargin-1, in+1); + break; + case 72: + MyFactorPosePoint2_print_72(nargout, out, nargin-1, in+1); break; } } catch(const std::exception& e) { diff --git a/wrap/tests/expected/matlab/functions_wrapper.cpp b/wrap/tests/expected/matlab/functions_wrapper.cpp index d0f0f8ca67..17b5fb494c 100644 --- a/wrap/tests/expected/matlab/functions_wrapper.cpp +++ b/wrap/tests/expected/matlab/functions_wrapper.cpp @@ -51,7 +51,7 @@ void _functions_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_functions_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -130,43 +130,110 @@ void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in int b = unwrap< int >(in[1]); DefaultFuncInt(a,b); } -void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncInt_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncInt",nargout,nargin,1); + int a = unwrap< int >(in[0]); + DefaultFuncInt(a,0); +} +void DefaultFuncInt_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncInt",nargout,nargin,0); + DefaultFuncInt(123,0); +} +void DefaultFuncString_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncString",nargout,nargin,2); string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); string& name = *unwrap_shared_ptr< string >(in[1], "ptr_string"); DefaultFuncString(s,name); } -void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncString_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncString",nargout,nargin,1); + string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); + DefaultFuncString(s,""); +} +void DefaultFuncString_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncString",nargout,nargin,0); + DefaultFuncString("hello",""); +} +void DefaultFuncObj_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncObj",nargout,nargin,1); gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[0], "ptr_gtsamKeyFormatter"); DefaultFuncObj(keyFormatter); } -void DefaultFuncZero_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncObj_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncObj",nargout,nargin,0); + DefaultFuncObj(gtsam::DefaultKeyFormatter); +} +void DefaultFuncZero_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncZero",nargout,nargin,5); int a = unwrap< int >(in[0]); int b = unwrap< int >(in[1]); double c = unwrap< double >(in[2]); - bool d = unwrap< bool >(in[3]); + int d = unwrap< int >(in[3]); bool e = unwrap< bool >(in[4]); DefaultFuncZero(a,b,c,d,e); } -void DefaultFuncVector_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncZero_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncZero",nargout,nargin,4); + int a = unwrap< int >(in[0]); + int b = unwrap< int >(in[1]); + double c = unwrap< double >(in[2]); + int d = unwrap< int >(in[3]); + DefaultFuncZero(a,b,c,d,false); +} +void DefaultFuncZero_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncZero",nargout,nargin,3); + int a = unwrap< int >(in[0]); + int b = unwrap< int >(in[1]); + double c = unwrap< double >(in[2]); + DefaultFuncZero(a,b,c,0,false); +} +void DefaultFuncZero_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncZero",nargout,nargin,2); + int a = unwrap< int >(in[0]); + int b = unwrap< int >(in[1]); + DefaultFuncZero(a,b,0.0,0,false); +} +void DefaultFuncVector_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncVector",nargout,nargin,2); std::vector& i = *unwrap_shared_ptr< std::vector >(in[0], "ptr_stdvectorint"); std::vector& s = *unwrap_shared_ptr< std::vector >(in[1], "ptr_stdvectorstring"); DefaultFuncVector(i,s); } -void setPose_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncVector_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncVector",nargout,nargin,1); + std::vector& i = *unwrap_shared_ptr< std::vector >(in[0], "ptr_stdvectorint"); + DefaultFuncVector(i,{"borglab", "gtsam"}); +} +void DefaultFuncVector_22(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncVector",nargout,nargin,0); + DefaultFuncVector({1, 2, 3},{"borglab", "gtsam"}); +} +void setPose_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("setPose",nargout,nargin,1); gtsam::Pose3& pose = *unwrap_shared_ptr< gtsam::Pose3 >(in[0], "ptr_gtsamPose3"); setPose(pose); } -void TemplatedFunctionRot3_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void setPose_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("setPose",nargout,nargin,0); + setPose(gtsam::Pose3()); +} +void TemplatedFunctionRot3_25(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("TemplatedFunctionRot3",nargout,nargin,1); gtsam::Rot3& t = *unwrap_shared_ptr< gtsam::Rot3 >(in[0], "ptr_gtsamRot3"); @@ -212,22 +279,55 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) DefaultFuncInt_8(nargout, out, nargin-1, in+1); break; case 9: - DefaultFuncString_9(nargout, out, nargin-1, in+1); + DefaultFuncInt_9(nargout, out, nargin-1, in+1); break; case 10: - DefaultFuncObj_10(nargout, out, nargin-1, in+1); + DefaultFuncInt_10(nargout, out, nargin-1, in+1); break; case 11: - DefaultFuncZero_11(nargout, out, nargin-1, in+1); + DefaultFuncString_11(nargout, out, nargin-1, in+1); break; case 12: - DefaultFuncVector_12(nargout, out, nargin-1, in+1); + DefaultFuncString_12(nargout, out, nargin-1, in+1); break; case 13: - setPose_13(nargout, out, nargin-1, in+1); + DefaultFuncString_13(nargout, out, nargin-1, in+1); break; case 14: - TemplatedFunctionRot3_14(nargout, out, nargin-1, in+1); + DefaultFuncObj_14(nargout, out, nargin-1, in+1); + break; + case 15: + DefaultFuncObj_15(nargout, out, nargin-1, in+1); + break; + case 16: + DefaultFuncZero_16(nargout, out, nargin-1, in+1); + break; + case 17: + DefaultFuncZero_17(nargout, out, nargin-1, in+1); + break; + case 18: + DefaultFuncZero_18(nargout, out, nargin-1, in+1); + break; + case 19: + DefaultFuncZero_19(nargout, out, nargin-1, in+1); + break; + case 20: + DefaultFuncVector_20(nargout, out, nargin-1, in+1); + break; + case 21: + DefaultFuncVector_21(nargout, out, nargin-1, in+1); + break; + case 22: + DefaultFuncVector_22(nargout, out, nargin-1, in+1); + break; + case 23: + setPose_23(nargout, out, nargin-1, in+1); + break; + case 24: + setPose_24(nargout, out, nargin-1, in+1); + break; + case 25: + TemplatedFunctionRot3_25(nargout, out, nargin-1, in+1); break; } } catch(const std::exception& e) { diff --git a/wrap/tests/expected/matlab/geometry_wrapper.cpp b/wrap/tests/expected/matlab/geometry_wrapper.cpp index 81631390c9..ee1f043595 100644 --- a/wrap/tests/expected/matlab/geometry_wrapper.cpp +++ b/wrap/tests/expected/matlab/geometry_wrapper.cpp @@ -118,9 +118,9 @@ void gtsamPoint2_deconstructor_3(int nargout, mxArray *out[], int nargin, const Collector_gtsamPoint2::iterator item; item = collector_gtsamPoint2.find(self); if(item != collector_gtsamPoint2.end()) { - delete self; collector_gtsamPoint2.erase(item); } + delete self; } void gtsamPoint2_argChar_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -262,9 +262,9 @@ void gtsamPoint3_deconstructor_20(int nargout, mxArray *out[], int nargin, const Collector_gtsamPoint3::iterator item; item = collector_gtsamPoint3.find(self); if(item != collector_gtsamPoint3.end()) { - delete self; collector_gtsamPoint3.erase(item); } + delete self; } void gtsamPoint3_norm_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -286,14 +286,14 @@ void gtsamPoint3_string_serialize_22(int nargout, mxArray *out[], int nargin, co } void gtsamPoint3_StaticFunctionRet_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("gtsamPoint3.StaticFunctionRet",nargout,nargin,1); + checkArguments("gtsam::Point3.StaticFunctionRet",nargout,nargin,1); double z = unwrap< double >(in[0]); out[0] = wrap< Point3 >(gtsam::Point3::StaticFunctionRet(z)); } void gtsamPoint3_staticFunction_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("gtsamPoint3.staticFunction",nargout,nargin,0); + checkArguments("gtsam::Point3.staticFunction",nargout,nargin,0); out[0] = wrap< double >(gtsam::Point3::staticFunction()); } diff --git a/wrap/tests/expected/matlab/inheritance_wrapper.cpp b/wrap/tests/expected/matlab/inheritance_wrapper.cpp index 8e61ac8c61..0cf17eedd0 100644 --- a/wrap/tests/expected/matlab/inheritance_wrapper.cpp +++ b/wrap/tests/expected/matlab/inheritance_wrapper.cpp @@ -88,7 +88,7 @@ void _inheritance_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_inheritance_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -121,9 +121,9 @@ void MyBase_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArr Collector_MyBase::iterator item; item = collector_MyBase.find(self); if(item != collector_MyBase.end()) { - delete self; collector_MyBase.erase(item); } + delete self; } void MyTemplatePoint2_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -171,9 +171,9 @@ void MyTemplatePoint2_deconstructor_6(int nargout, mxArray *out[], int nargin, c Collector_MyTemplatePoint2::iterator item; item = collector_MyTemplatePoint2.find(self); if(item != collector_MyTemplatePoint2.end()) { - delete self; collector_MyTemplatePoint2.erase(item); } + delete self; } void MyTemplatePoint2_accept_T_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -289,7 +289,7 @@ void MyTemplatePoint2_templatedMethod_17(int nargout, mxArray *out[], int nargin void MyTemplatePoint2_Level_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("MyTemplatePoint2.Level",nargout,nargin,1); + checkArguments("MyTemplate.Level",nargout,nargin,1); Point2 K = unwrap< Point2 >(in[0]); out[0] = wrap_shared_ptr(boost::make_shared>(MyTemplate::Level(K)),"MyTemplatePoint2", false); } @@ -339,9 +339,9 @@ void MyTemplateMatrix_deconstructor_22(int nargout, mxArray *out[], int nargin, Collector_MyTemplateMatrix::iterator item; item = collector_MyTemplateMatrix.find(self); if(item != collector_MyTemplateMatrix.end()) { - delete self; collector_MyTemplateMatrix.erase(item); } + delete self; } void MyTemplateMatrix_accept_T_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -457,7 +457,7 @@ void MyTemplateMatrix_templatedMethod_33(int nargout, mxArray *out[], int nargin void MyTemplateMatrix_Level_34(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("MyTemplateMatrix.Level",nargout,nargin,1); + checkArguments("MyTemplate.Level",nargout,nargin,1); Matrix K = unwrap< Matrix >(in[0]); out[0] = wrap_shared_ptr(boost::make_shared>(MyTemplate::Level(K)),"MyTemplateMatrix", false); } @@ -492,9 +492,9 @@ void ForwardKinematicsFactor_deconstructor_37(int nargout, mxArray *out[], int n Collector_ForwardKinematicsFactor::iterator item; item = collector_ForwardKinematicsFactor.find(self); if(item != collector_ForwardKinematicsFactor.end()) { - delete self; collector_ForwardKinematicsFactor.erase(item); } + delete self; } diff --git a/wrap/tests/expected/matlab/multiple_files_wrapper.cpp b/wrap/tests/expected/matlab/multiple_files_wrapper.cpp index 66ab7ff73d..864ae75d62 100644 --- a/wrap/tests/expected/matlab/multiple_files_wrapper.cpp +++ b/wrap/tests/expected/matlab/multiple_files_wrapper.cpp @@ -75,7 +75,7 @@ void _multiple_files_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_multiple_files_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -110,9 +110,9 @@ void gtsamClass1_deconstructor_2(int nargout, mxArray *out[], int nargin, const Collector_gtsamClass1::iterator item; item = collector_gtsamClass1.find(self); if(item != collector_gtsamClass1.end()) { - delete self; collector_gtsamClass1.erase(item); } + delete self; } void gtsamClass2_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -143,9 +143,9 @@ void gtsamClass2_deconstructor_5(int nargout, mxArray *out[], int nargin, const Collector_gtsamClass2::iterator item; item = collector_gtsamClass2.find(self); if(item != collector_gtsamClass2.end()) { - delete self; collector_gtsamClass2.erase(item); } + delete self; } void gtsamClassA_collectorInsertAndMakeBase_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -176,9 +176,9 @@ void gtsamClassA_deconstructor_8(int nargout, mxArray *out[], int nargin, const Collector_gtsamClassA::iterator item; item = collector_gtsamClassA.find(self); if(item != collector_gtsamClassA.end()) { - delete self; collector_gtsamClassA.erase(item); } + delete self; } diff --git a/wrap/tests/expected/matlab/namespaces_wrapper.cpp b/wrap/tests/expected/matlab/namespaces_wrapper.cpp index 604ede5da5..b2fe3eed62 100644 --- a/wrap/tests/expected/matlab/namespaces_wrapper.cpp +++ b/wrap/tests/expected/matlab/namespaces_wrapper.cpp @@ -112,7 +112,7 @@ void _namespaces_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_namespaces_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -147,9 +147,9 @@ void ns1ClassA_deconstructor_2(int nargout, mxArray *out[], int nargin, const mx Collector_ns1ClassA::iterator item; item = collector_ns1ClassA.find(self); if(item != collector_ns1ClassA.end()) { - delete self; collector_ns1ClassA.erase(item); } + delete self; } void ns1ClassB_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -180,9 +180,9 @@ void ns1ClassB_deconstructor_5(int nargout, mxArray *out[], int nargin, const mx Collector_ns1ClassB::iterator item; item = collector_ns1ClassB.find(self); if(item != collector_ns1ClassB.end()) { - delete self; collector_ns1ClassB.erase(item); } + delete self; } void aGlobalFunction_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -218,9 +218,9 @@ void ns2ClassA_deconstructor_9(int nargout, mxArray *out[], int nargin, const mx Collector_ns2ClassA::iterator item; item = collector_ns2ClassA.find(self); if(item != collector_ns2ClassA.end()) { - delete self; collector_ns2ClassA.erase(item); } + delete self; } void ns2ClassA_memberFunction_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -248,7 +248,7 @@ void ns2ClassA_nsReturn_12(int nargout, mxArray *out[], int nargin, const mxArra void ns2ClassA_afunction_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("ns2ClassA.afunction",nargout,nargin,0); + checkArguments("ns2::ClassA.afunction",nargout,nargin,0); out[0] = wrap< double >(ns2::ClassA::afunction()); } @@ -280,9 +280,9 @@ void ns2ns3ClassB_deconstructor_16(int nargout, mxArray *out[], int nargin, cons Collector_ns2ns3ClassB::iterator item; item = collector_ns2ns3ClassB.find(self); if(item != collector_ns2ns3ClassB.end()) { - delete self; collector_ns2ns3ClassB.erase(item); } + delete self; } void ns2ClassC_collectorInsertAndMakeBase_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -313,9 +313,9 @@ void ns2ClassC_deconstructor_19(int nargout, mxArray *out[], int nargin, const m Collector_ns2ClassC::iterator item; item = collector_ns2ClassC.find(self); if(item != collector_ns2ClassC.end()) { - delete self; collector_ns2ClassC.erase(item); } + delete self; } void aGlobalFunction_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -364,9 +364,9 @@ void ClassD_deconstructor_25(int nargout, mxArray *out[], int nargin, const mxAr Collector_ClassD::iterator item; item = collector_ClassD.find(self); if(item != collector_ClassD.end()) { - delete self; collector_ClassD.erase(item); } + delete self; } void gtsamValues_collectorInsertAndMakeBase_26(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -409,9 +409,9 @@ void gtsamValues_deconstructor_29(int nargout, mxArray *out[], int nargin, const Collector_gtsamValues::iterator item; item = collector_gtsamValues.find(self); if(item != collector_gtsamValues.end()) { - delete self; collector_gtsamValues.erase(item); } + delete self; } void gtsamValues_insert_30(int nargout, mxArray *out[], int nargin, const mxArray *in[]) diff --git a/wrap/tests/expected/matlab/setPose.m b/wrap/tests/expected/matlab/setPose.m new file mode 100644 index 0000000000..d45cc56921 --- /dev/null +++ b/wrap/tests/expected/matlab/setPose.m @@ -0,0 +1,8 @@ +function varargout = setPose(varargin) + if length(varargin) == 1 && isa(varargin{1},'gtsam.Pose3') + functions_wrapper(23, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(24, varargin{:}); + else + error('Arguments do not match any overload of function setPose'); + end diff --git a/wrap/tests/expected/matlab/special_cases_wrapper.cpp b/wrap/tests/expected/matlab/special_cases_wrapper.cpp index 69abbf73be..c6704c20f8 100644 --- a/wrap/tests/expected/matlab/special_cases_wrapper.cpp +++ b/wrap/tests/expected/matlab/special_cases_wrapper.cpp @@ -84,7 +84,7 @@ void _special_cases_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_special_cases_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -108,9 +108,9 @@ void gtsamNonlinearFactorGraph_deconstructor_1(int nargout, mxArray *out[], int Collector_gtsamNonlinearFactorGraph::iterator item; item = collector_gtsamNonlinearFactorGraph.find(self); if(item != collector_gtsamNonlinearFactorGraph.end()) { - delete self; collector_gtsamNonlinearFactorGraph.erase(item); } + delete self; } void gtsamNonlinearFactorGraph_addPrior_2(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -140,9 +140,9 @@ void gtsamSfmTrack_deconstructor_4(int nargout, mxArray *out[], int nargin, cons Collector_gtsamSfmTrack::iterator item; item = collector_gtsamSfmTrack.find(self); if(item != collector_gtsamSfmTrack.end()) { - delete self; collector_gtsamSfmTrack.erase(item); } + delete self; } void gtsamPinholeCameraCal3Bundler_collectorInsertAndMakeBase_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -162,9 +162,9 @@ void gtsamPinholeCameraCal3Bundler_deconstructor_6(int nargout, mxArray *out[], Collector_gtsamPinholeCameraCal3Bundler::iterator item; item = collector_gtsamPinholeCameraCal3Bundler.find(self); if(item != collector_gtsamPinholeCameraCal3Bundler.end()) { - delete self; collector_gtsamPinholeCameraCal3Bundler.erase(item); } + delete self; } void gtsamGeneralSFMFactorCal3Bundler_collectorInsertAndMakeBase_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -184,9 +184,9 @@ void gtsamGeneralSFMFactorCal3Bundler_deconstructor_8(int nargout, mxArray *out[ Collector_gtsamGeneralSFMFactorCal3Bundler::iterator item; item = collector_gtsamGeneralSFMFactorCal3Bundler.find(self); if(item != collector_gtsamGeneralSFMFactorCal3Bundler.end()) { - delete self; collector_gtsamGeneralSFMFactorCal3Bundler.erase(item); } + delete self; } diff --git a/wrap/tests/expected/matlab/template_wrapper.cpp b/wrap/tests/expected/matlab/template_wrapper.cpp index 532bdd57a9..a0b1aaa7e2 100644 --- a/wrap/tests/expected/matlab/template_wrapper.cpp +++ b/wrap/tests/expected/matlab/template_wrapper.cpp @@ -67,7 +67,7 @@ void _template_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_template_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -138,9 +138,9 @@ void TemplatedConstructor_deconstructor_5(int nargout, mxArray *out[], int nargi Collector_TemplatedConstructor::iterator item; item = collector_TemplatedConstructor.find(self); if(item != collector_TemplatedConstructor.end()) { - delete self; collector_TemplatedConstructor.erase(item); } + delete self; } void ScopedTemplateResult_collectorInsertAndMakeBase_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -172,9 +172,9 @@ void ScopedTemplateResult_deconstructor_8(int nargout, mxArray *out[], int nargi Collector_ScopedTemplateResult::iterator item; item = collector_ScopedTemplateResult.find(self); if(item != collector_ScopedTemplateResult.end()) { - delete self; collector_ScopedTemplateResult.erase(item); } + delete self; } diff --git a/wrap/tests/expected/python/class_pybind.cpp b/wrap/tests/expected/python/class_pybind.cpp index a54d9135b2..fd53989126 100644 --- a/wrap/tests/expected/python/class_pybind.cpp +++ b/wrap/tests/expected/python/class_pybind.cpp @@ -31,6 +31,7 @@ PYBIND11_MODULE(class_py, m_) { py::class_, std::shared_ptr>>(m_, "FunDouble") .def("templatedMethodString",[](Fun* self, double d, string t){return self->templatedMethod(d, t);}, py::arg("d"), py::arg("t")) .def("multiTemplatedMethodStringSize_t",[](Fun* self, double d, string t, size_t u){return self->multiTemplatedMethod(d, t, u);}, py::arg("d"), py::arg("t"), py::arg("u")) + .def("sets",[](Fun* self){return self->sets();}) .def_static("staticMethodWithThis",[](){return Fun::staticMethodWithThis();}) .def_static("templatedStaticMethodInt",[](const int& m){return Fun::templatedStaticMethod(m);}, py::arg("m")); @@ -68,6 +69,7 @@ PYBIND11_MODULE(class_py, m_) { .def("set_container",[](Test* self, std::vector> container){ self->set_container(container);}, py::arg("container")) .def("set_container",[](Test* self, std::vector container){ self->set_container(container);}, py::arg("container")) .def("get_container",[](Test* self){return self->get_container();}) + .def("_repr_markdown_",[](Test* self, const gtsam::KeyFormatter& keyFormatter){return self->markdown(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter) .def_readwrite("model_ptr", &Test::model_ptr); py::class_, std::shared_ptr>>(m_, "PrimitiveRefDouble") diff --git a/wrap/tests/expected/python/functions_pybind.cpp b/wrap/tests/expected/python/functions_pybind.cpp index bee95ec03f..8511b5d3cd 100644 --- a/wrap/tests/expected/python/functions_pybind.cpp +++ b/wrap/tests/expected/python/functions_pybind.cpp @@ -33,7 +33,7 @@ PYBIND11_MODULE(functions_py, m_) { m_.def("DefaultFuncInt",[](int a, int b){ ::DefaultFuncInt(a, b);}, py::arg("a") = 123, py::arg("b") = 0); m_.def("DefaultFuncString",[](const string& s, const string& name){ ::DefaultFuncString(s, name);}, py::arg("s") = "hello", py::arg("name") = ""); m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter); - m_.def("DefaultFuncZero",[](int a, int b, double c, bool d, bool e){ ::DefaultFuncZero(a, b, c, d, e);}, py::arg("a") = 0, py::arg("b"), py::arg("c") = 0.0, py::arg("d") = false, py::arg("e")); + m_.def("DefaultFuncZero",[](int a, int b, double c, int d, bool e){ ::DefaultFuncZero(a, b, c, d, e);}, py::arg("a"), py::arg("b"), py::arg("c") = 0.0, py::arg("d") = 0, py::arg("e") = false); m_.def("DefaultFuncVector",[](const std::vector& i, const std::vector& s){ ::DefaultFuncVector(i, s);}, py::arg("i") = {1, 2, 3}, py::arg("s") = {"borglab", "gtsam"}); m_.def("setPose",[](const gtsam::Pose3& pose){ ::setPose(pose);}, py::arg("pose") = gtsam::Pose3()); m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction(t);}, py::arg("t")); diff --git a/wrap/tests/fixtures/class.i b/wrap/tests/fixtures/class.i index 40a5655064..f38c27411d 100644 --- a/wrap/tests/fixtures/class.i +++ b/wrap/tests/fixtures/class.i @@ -18,6 +18,8 @@ class Fun { template This multiTemplatedMethod(double d, T t, U u); + + std::map sets(); }; @@ -75,6 +77,10 @@ class Test { void set_container(std::vector container); std::vector get_container() const; + // special ipython method + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + // comments at the end! // even more comments at the end! diff --git a/wrap/tests/fixtures/functions.i b/wrap/tests/fixtures/functions.i index 0852a3e1e9..7f3c833328 100644 --- a/wrap/tests/fixtures/functions.i +++ b/wrap/tests/fixtures/functions.i @@ -31,7 +31,7 @@ typedef TemplatedFunction TemplatedFunctionRot3; void DefaultFuncInt(int a = 123, int b = 0); void DefaultFuncString(const string& s = "hello", const string& name = ""); void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); -void DefaultFuncZero(int a = 0, int b, double c = 0.0, bool d = false, bool e); +void DefaultFuncZero(int a, int b, double c = 0.0, int d = 0, bool e = false); void DefaultFuncVector(const std::vector &i = {1, 2, 3}, const std::vector &s = {"borglab", "gtsam"}); // Test for non-trivial default constructor diff --git a/wrap/tests/fixtures/geometry.i b/wrap/tests/fixtures/geometry.i index a7b900f805..e1460666c0 100644 --- a/wrap/tests/fixtures/geometry.i +++ b/wrap/tests/fixtures/geometry.i @@ -24,9 +24,6 @@ class Point2 { VectorNotEigen vectorConfusion(); void serializable() const; // Sets flag and creates export, but does not make serialization functions - - // enable pickling in python - void pickle() const; }; #include @@ -40,9 +37,6 @@ class Point3 { // enabling serialization functionality void serialize() const; // Just triggers a flag internally and removes actual function - - // enable pickling in python - void pickle() const; }; } diff --git a/wrap/tests/test_interface_parser.py b/wrap/tests/test_interface_parser.py index 49165328c9..2603e9db4a 100644 --- a/wrap/tests/test_interface_parser.py +++ b/wrap/tests/test_interface_parser.py @@ -657,8 +657,6 @@ class Global{ int globalVar; """) - # print("module: ", module) - # print(dir(module.content[0].name)) self.assertEqual(["one", "Global", "globalVar"], [x.name for x in module.content]) self.assertEqual(["two", "two_dummy", "two", "oneVar"], diff --git a/wrap/tests/test_matlab_wrapper.py b/wrap/tests/test_matlab_wrapper.py index 34940d62ef..43fedf7aa3 100644 --- a/wrap/tests/test_matlab_wrapper.py +++ b/wrap/tests/test_matlab_wrapper.py @@ -92,10 +92,19 @@ def test_functions(self): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'functions_wrapper.cpp', 'aGlobalFunction.m', 'load2D.m', + 'functions_wrapper.cpp', + 'aGlobalFunction.m', + 'load2D.m', 'MultiTemplatedFunctionDoubleSize_tDouble.m', 'MultiTemplatedFunctionStringSize_tDouble.m', - 'overloadedGlobalFunction.m', 'TemplatedFunctionRot3.m' + 'overloadedGlobalFunction.m', + 'TemplatedFunctionRot3.m', + 'DefaultFuncInt.m', + 'DefaultFuncObj.m', + 'DefaultFuncString.m', + 'DefaultFuncVector.m', + 'DefaultFuncZero.m', + 'setPose.m', ] for file in files: @@ -115,10 +124,17 @@ def test_class(self): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'class_wrapper.cpp', 'FunDouble.m', 'FunRange.m', - 'MultipleTemplatesIntDouble.m', 'MultipleTemplatesIntFloat.m', - 'MyFactorPosePoint2.m', 'MyVector3.m', 'MyVector12.m', - 'PrimitiveRefDouble.m', 'Test.m' + 'class_wrapper.cpp', + 'FunDouble.m', + 'FunRange.m', + 'MultipleTemplatesIntDouble.m', + 'MultipleTemplatesIntFloat.m', + 'MyFactorPosePoint2.m', + 'MyVector3.m', + 'MyVector12.m', + 'PrimitiveRefDouble.m', + 'Test.m', + 'ForwardKinematics.m', ] for file in files: @@ -137,7 +153,10 @@ def test_templates(self): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) - files = ['template_wrapper.cpp'] + files = [ + 'template_wrapper.cpp', 'ScopedTemplateResult.m', + 'TemplatedConstructor.m' + ] for file in files: actual = osp.join(self.MATLAB_ACTUAL_DIR, file) @@ -155,8 +174,11 @@ def test_inheritance(self): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'inheritance_wrapper.cpp', 'MyBase.m', 'MyTemplateMatrix.m', - 'MyTemplatePoint2.m' + 'inheritance_wrapper.cpp', + 'MyBase.m', + 'MyTemplateMatrix.m', + 'MyTemplatePoint2.m', + 'ForwardKinematicsFactor.m', ] for file in files: @@ -178,10 +200,17 @@ def test_namespaces(self): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'namespaces_wrapper.cpp', '+ns1/aGlobalFunction.m', - '+ns1/ClassA.m', '+ns1/ClassB.m', '+ns2/+ns3/ClassB.m', - '+ns2/aGlobalFunction.m', '+ns2/ClassA.m', '+ns2/ClassC.m', - '+ns2/overloadedGlobalFunction.m', 'ClassD.m' + 'namespaces_wrapper.cpp', + '+ns1/aGlobalFunction.m', + '+ns1/ClassA.m', + '+ns1/ClassB.m', + '+ns2/+ns3/ClassB.m', + '+ns2/aGlobalFunction.m', + '+ns2/ClassA.m', + '+ns2/ClassC.m', + '+ns2/overloadedGlobalFunction.m', + 'ClassD.m', + '+gtsam/Values.m', ] for file in files: @@ -203,8 +232,10 @@ def test_special_cases(self): files = [ 'special_cases_wrapper.cpp', - '+gtsam/PinholeCameraCal3Bundler.m', + '+gtsam/GeneralSFMFactorCal3Bundler.m', '+gtsam/NonlinearFactorGraph.m', + '+gtsam/PinholeCameraCal3Bundler.m', + '+gtsam/SfmTrack.m', ] for file in files: