diff --git a/src/fastcpd.cc b/src/fastcpd.cc index 7822e273..006976f3 100755 --- a/src/fastcpd.cc +++ b/src/fastcpd.cc @@ -119,6 +119,28 @@ arma::colvec cost_update_gradient( return gradient; } +arma::mat cost_update_hessian( + arma::mat data, + arma::colvec theta, + std::string family, + double min_prob +) { + arma::rowvec new_data = data.row(data.n_rows - 1); + arma::rowvec x = new_data.tail(new_data.n_elem - 1); + arma::mat hessian; + if (family.compare("binomial") == 0) { + double prob = 1 / (1 + exp(-arma::as_scalar(x * theta))); + hessian = (x.t() * x) * arma::as_scalar((1 - prob) * prob); + } else if (family.compare("poisson") == 0) { + double prob = exp(arma::as_scalar(x * theta)); + hessian = (x.t() * x) * std::min(arma::as_scalar(prob), min_prob); + } else { + // `family` is either "lasso" or "gaussian". + hessian = x.t() * x; + } + return hessian; +} + Rcpp::List cost_update( const arma::mat data, arma::mat theta_hat, diff --git a/src/cost_update_hessian.cc b/src/fastcpd_test.cc old mode 100755 new mode 100644 similarity index 57% rename from src/cost_update_hessian.cc rename to src/fastcpd_test.cc index 43858af1..d4353aa9 --- a/src/cost_update_hessian.cc +++ b/src/fastcpd_test.cc @@ -1,27 +1,5 @@ #include "fastcpd.h" -arma::mat cost_update_hessian( - arma::mat data, - arma::colvec theta, - std::string family, - double min_prob -) { - arma::rowvec new_data = data.row(data.n_rows - 1); - arma::rowvec x = new_data.tail(new_data.n_elem - 1); - arma::mat hessian; - if (family.compare("binomial") == 0) { - double prob = 1 / (1 + exp(-arma::as_scalar(x * theta))); - hessian = (x.t() * x) * arma::as_scalar((1 - prob) * prob); - } else if (family.compare("poisson") == 0) { - double prob = exp(arma::as_scalar(x * theta)); - hessian = (x.t() * x) * std::min(arma::as_scalar(prob), min_prob); - } else { - // `family` is either "lasso" or "gaussian". - hessian = x.t() * x; - } - return hessian; -} - context("cost_update_hessian Unit Test") { test_that("binomal is correct for a two dimensional data") { arma::colvec theta = {-0.5, 0.3};