Skip to content

Commit

Permalink
Format workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
doccstat committed Aug 25, 2023
1 parent 25ae79d commit 6567be4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
22 changes: 22 additions & 0 deletions src/fastcpd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,28 @@ arma::colvec cost_update_gradient(
return gradient;
}

arma::mat cost_update_hessian(
arma::mat data,
arma::colvec theta,
std::string family,
double min_prob
) {
arma::rowvec new_data = data.row(data.n_rows - 1);
arma::rowvec x = new_data.tail(new_data.n_elem - 1);
arma::mat hessian;
if (family.compare("binomial") == 0) {
double prob = 1 / (1 + exp(-arma::as_scalar(x * theta)));
hessian = (x.t() * x) * arma::as_scalar((1 - prob) * prob);
} else if (family.compare("poisson") == 0) {
double prob = exp(arma::as_scalar(x * theta));
hessian = (x.t() * x) * std::min(arma::as_scalar(prob), min_prob);
} else {
// `family` is either "lasso" or "gaussian".
hessian = x.t() * x;
}
return hessian;
}

Rcpp::List cost_update(
const arma::mat data,
arma::mat theta_hat,
Expand Down
22 changes: 0 additions & 22 deletions src/cost_update_hessian.cc → src/fastcpd_test.cc
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,27 +1,5 @@
#include "fastcpd.h"

arma::mat cost_update_hessian(
arma::mat data,
arma::colvec theta,
std::string family,
double min_prob
) {
arma::rowvec new_data = data.row(data.n_rows - 1);
arma::rowvec x = new_data.tail(new_data.n_elem - 1);
arma::mat hessian;
if (family.compare("binomial") == 0) {
double prob = 1 / (1 + exp(-arma::as_scalar(x * theta)));
hessian = (x.t() * x) * arma::as_scalar((1 - prob) * prob);
} else if (family.compare("poisson") == 0) {
double prob = exp(arma::as_scalar(x * theta));
hessian = (x.t() * x) * std::min(arma::as_scalar(prob), min_prob);
} else {
// `family` is either "lasso" or "gaussian".
hessian = x.t() * x;
}
return hessian;
}

context("cost_update_hessian Unit Test") {
test_that("binomal is correct for a two dimensional data") {
arma::colvec theta = {-0.5, 0.3};
Expand Down

0 comments on commit 6567be4

Please sign in to comment.