Skip to content

Commit

Permalink
Simplify custom families
Browse files Browse the repository at this point in the history
  • Loading branch information
doccstat committed Apr 25, 2024
1 parent c9aa2e0 commit 0c57b2d
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 287 deletions.
184 changes: 43 additions & 141 deletions src/fastcpd_class.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ Fastcpd::Fastcpd(
get_hessian = &Fastcpd::get_hessian_arma;
get_nll_sen = &Fastcpd::get_nll_sen_arma;
} else {
// No sequential gradient performed.
get_gradient = &Fastcpd::get_gradient_custom;
get_hessian = &Fastcpd::get_hessian_custom;
get_nll_sen = &Fastcpd::get_nll_sen_custom;
}

if (family == "lasso") {
Expand All @@ -113,13 +115,9 @@ Fastcpd::Fastcpd(
} else if (family == "mgaussian") {
get_nll_pelt = &Fastcpd::get_nll_pelt_mgaussian;
} else {
// Not a built-in family.
get_nll_pelt = &Fastcpd::get_nll_pelt_custom;
}

create_cost_function_wrapper(cost);
create_cost_gradient_wrapper(cost_gradient);
create_cost_hessian_wrapper(cost_hessian);

// TODO(doccstat): Store environment functions from R.
}

Expand Down Expand Up @@ -211,67 +209,6 @@ void Fastcpd::create_clock_in_r(const std::string name) {
}
}

// Install `cost_function_wrapper`.
//
// When `family` is "custom", the wrapper is a
// `fastcpd::classes::CostFunction` constructed from the R function `cost` and
// `data`. For every other family the wrapper forwards its six arguments to
// `Fastcpd::get_cost_result` on this object.
//
// @param cost R cost function; only dereferenced for the "custom" family.
void Fastcpd::create_cost_function_wrapper(Nullable<Function> cost) {
  DEBUG_RCOUT(family);

  if (family == "custom") {
    // User-supplied cost: wrap the R function together with the data.
    fastcpd::classes::CostFunction wrapped_cost(cost.get(), data);
    cost_function_wrapper = wrapped_cost;
    return;
  }

  // Built-in family: bind the member cost routine to this instance.
  cost_function_wrapper = std::bind(  // # nocov start
    &Fastcpd::get_cost_result,  // # nocov end
    this,
    std::placeholders::_1,
    std::placeholders::_2,
    std::placeholders::_3,
    std::placeholders::_4,
    std::placeholders::_5,
    std::placeholders::_6
  );
}

// Install `cost_gradient_wrapper`.
//
// For every family other than "custom" the wrapper forwards its three
// arguments to the member function currently stored in `get_gradient`, bound
// to this object. For the "custom" family it is a
// `fastcpd::classes::CostGradient` constructed from the R function
// `cost_gradient` and `data`; when `cost_gradient` is `NULL` the wrapper is
// left untouched.
//
// @param cost_gradient Optional R gradient function; only consulted when
//   `family` is "custom".
void Fastcpd::create_cost_gradient_wrapper(Nullable<Function> cost_gradient) {
  if (family != "custom") {
    cost_gradient_wrapper = std::bind(  // # nocov start
      get_gradient,  // # nocov end
      this,
      std::placeholders::_1,
      std::placeholders::_2,
      std::placeholders::_3
    );
  } else if (cost_gradient.isNotNull()) {
    fastcpd::classes::CostGradient costGradient(cost_gradient.get(), data);
    cost_gradient_wrapper = costGradient;
  }
  // Otherwise `cost_gradient` is `NULL`, which is valid in the case of vanilla
  // PELT, so the wrapper stays unset. `isNull()` and `isNotNull()` are exact
  // complements, hence no further "unreachable" branch is required.
}

// Install `cost_hessian_wrapper`.
//
// For every family other than "custom" the wrapper forwards its three
// arguments to the member function currently stored in `get_hessian`, bound
// to this object. For the "custom" family it is a
// `fastcpd::classes::CostHessian` constructed from the R function
// `cost_hessian` and `data`; when `cost_hessian` is `NULL` the wrapper is
// left untouched.
//
// @param cost_hessian Optional R Hessian function; only consulted when
//   `family` is "custom".
void Fastcpd::create_cost_hessian_wrapper(Nullable<Function> cost_hessian) {
  if (family != "custom") {
    cost_hessian_wrapper = std::bind(  // # nocov start
      get_hessian,  // # nocov end
      this,
      std::placeholders::_1,
      std::placeholders::_2,
      std::placeholders::_3
    );
  } else if (cost_hessian.isNotNull()) {
    fastcpd::classes::CostHessian costHessian(cost_hessian.get(), data);
    cost_hessian_wrapper = costHessian;
  }
  // Otherwise `cost_hessian` is `NULL`, which is valid in the case of vanilla
  // PELT, so the wrapper stays unset. `isNull()` and `isNotNull()` are exact
  // complements, hence no further "unreachable" branch is required.
}

void Fastcpd::create_gradients() {
if (vanilla_percentage == 1) return;
if (family == "binomial") {
Expand Down Expand Up @@ -305,21 +242,14 @@ void Fastcpd::create_segment_statistics() {
for (
int segment_index = 0; segment_index < segment_count; ++segment_index
) {
rowvec segment_theta;
if (family == "custom") {
segment_theta = get_optimized_cost(
segment_indices(segment_index), segment_indices(segment_index + 1) - 1
).par;
} else {
segment_theta = get_cost_result(
segment_indices(segment_index),
segment_indices(segment_index + 1) - 1,
R_NilValue,
0,
true,
R_NilValue
).par;
}
rowvec segment_theta = get_cost_result(
segment_indices(segment_index),
segment_indices(segment_index + 1) - 1,
R_NilValue,
0,
true,
R_NilValue
).par;

// Initialize the estimated coefficients for each segment to be the
// estimated coefficients in the segment.
Expand Down Expand Up @@ -419,14 +349,9 @@ List Fastcpd::get_cp_set(const colvec raw_cp_set, const double lambda) {
unsigned int residual_next_start = 0;

for (unsigned int i = 0; i < cp_loc.n_elem - 1; i++) {
CostResult cost_result;
if (family == "custom") {
cost_result = get_optimized_cost(cp_loc(i), cp_loc(i + 1) - 1);
} else {
cost_result = get_cost_result(
cp_loc(i), cp_loc(i + 1) - 1, R_NilValue, lambda, false, R_NilValue
);
}
CostResult cost_result = get_cost_result(
cp_loc(i), cp_loc(i + 1) - 1, R_NilValue, lambda, false, R_NilValue
);

cost_values(i) = cost_result.value;

Expand All @@ -437,8 +362,8 @@ List Fastcpd::get_cp_set(const colvec raw_cp_set, const double lambda) {

// Residual is only calculated for built-in families.
if (
family != "custom" &&
!(family == "mean" || family == "variance" || family == "meanvariance")
family != "custom" && family != "mean" &&
family != "variance" && family != "meanvariance"
) {
mat cost_optim_residual = cost_result.residuals;
residual.rows(
Expand Down Expand Up @@ -483,23 +408,22 @@ double Fastcpd::get_cval_pelt(
) {
double cval = 0;
CostResult cost_result;
if (family == "custom") {
cost_result = get_optimized_cost(segment_start, segment_end);
if (
(family == "binomial" || family == "poisson") &&
(warm_start && segment_end + 1 - segment_start >= 10 * p)
) {
cost_result = get_cost_result(
segment_start, segment_end, R_NilValue, lambda, false,
wrap(
segment_theta_hat.row(index_max(find(segment_indices <= segment_end))
).t())
// Or use `wrap(start.col(segment_start))` for warm start.
);
update_start(segment_start, colvec(cost_result.par));
} else {
if (warm_start && segment_end + 1 - segment_start >= 10 * p) {
cost_result = get_cost_result(
segment_start, segment_end, R_NilValue, lambda, false,
wrap(
segment_theta_hat.row(index_max(find(segment_indices <= segment_end))
).t())
// Or use `wrap(start.col(segment_start))` for warm start.
);
update_start(segment_start, colvec(cost_result.par));
} else {
cost_result = get_cost_result(
segment_start, segment_end, R_NilValue, lambda, false, R_NilValue
);
}
cost_result = get_cost_result(
segment_start, segment_end, R_NilValue, lambda, false, R_NilValue
);
}
cval = cost_result.value;

Expand Down Expand Up @@ -528,11 +452,9 @@ double Fastcpd::get_cval_sen(
colvec theta = theta_sum.col(i) / segment_length;
DEBUG_RCOUT(theta);
if (family == "custom") {
Function cost_non_null = cost.get();
SEXP cost_result = cost_non_null(
data.rows(segment_start, segment_end), theta
cval = (this->*get_nll_sen)(
segment_start, segment_end, theta, lambda
);
cval = as<double>(cost_result);
} else if (
(family != "lasso" && segment_length >= p) ||
(family == "lasso" && segment_length >= 3)
Expand Down Expand Up @@ -598,10 +520,6 @@ CostResult Fastcpd::get_optimized_cost(
return cost_result;
}

// Accessor: returns the current value of the `theta_sum` member.
mat Fastcpd::get_theta_sum() {
  return this->theta_sum;
}

void Fastcpd::update_cost_parameters(
const unsigned int t,
const unsigned int tau,
Expand Down Expand Up @@ -632,23 +550,12 @@ void Fastcpd::update_cost_parameters_step(
mat hessian_i = hessian.slice(i);
colvec gradient;

if (family == "custom") {
mat cost_hessian_result = cost_hessian_wrapper(
segment_start + data_start, segment_start + data_end, theta_hat.col(i)
);
hessian_i += cost_hessian_result;
colvec cost_gradient_result = cost_gradient_wrapper(
segment_start + data_start, segment_start + data_end, theta_hat.col(i)
);
gradient = cost_gradient_result;
} else {
hessian_i += (this->*get_hessian)(
segment_start + data_start, segment_start + data_end, theta_hat.col(i)
);
gradient = (this->*get_gradient)(
segment_start + data_start, segment_start + data_end, theta_hat.col(i)
);
}
hessian_i += (this->*get_hessian)(
segment_start + data_start, segment_start + data_end, theta_hat.col(i)
);
gradient = (this->*get_gradient)(
segment_start + data_start, segment_start + data_end, theta_hat.col(i)
);

// Add epsilon to the diagonal for PSD hessian
mat hessian_psd =
Expand All @@ -671,14 +578,9 @@ void Fastcpd::update_cost_parameters_step(
theta_hat.col(i) + line_search[line_search_index] * momentum;
colvec theta_upper_bound = arma::min(std::move(theta_candidate), upper);
colvec theta_projected = arma::max(std::move(theta_upper_bound), lower);
line_search_costs[line_search_index] = cost_function_wrapper(
segment_start,
segment_end,
wrap(theta_projected),
lambda,
false,
R_NilValue
).value;
line_search_costs[line_search_index] = (this->*get_nll_sen)(
segment_start, segment_end, theta_projected, lambda
);
}
}
best_learning_rate = line_search[line_search_costs.index_min()];
Expand Down
61 changes: 27 additions & 34 deletions src/fastcpd_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,39 +59,12 @@ class Fastcpd {
// Adjustment to the cost function.
const string cost_adjustment;

// Cost function. If the cost function is provided in R, this will be a
// wrapper of the R function.
function<CostResult(
const unsigned int segment_start,
const unsigned int segment_end,
Nullable<colvec> theta,
const double lambda,
const bool cv,
Nullable<colvec> start
)> cost_function_wrapper;

// `cost_gradient` is the gradient of the cost function to be used.
Nullable<Function> cost_gradient;

// Gradient of the cost function. If the cost function is provided in R, this
// will be a wrapper of the R function.
function<colvec(
const unsigned int segment_start,
const unsigned int segment_end,
const colvec& theta
)> cost_gradient_wrapper;

// `cost_hessian` is the Hessian of the cost function to be used.
Nullable<Function> cost_hessian;

// Hessian of the cost function. If the cost function is provided in R, this
// will be a wrapper of the R function.
function<mat(
const unsigned int segment_start,
const unsigned int segment_end,
const colvec& theta
)> cost_hessian_wrapper;

const bool cp_only;

// Dimension of the data.
Expand Down Expand Up @@ -224,10 +197,6 @@ class Fastcpd {
// Stop the clock and create an R object with `name`.
void create_clock_in_r(const std::string name);

void create_cost_function_wrapper(Nullable<Function> cost);
void create_cost_gradient_wrapper(Nullable<Function> cost_gradient);
void create_cost_hessian_wrapper(Nullable<Function> cost_hessian);

// Initialize \code{theta_hat}, \code{theta_sum}, and \code{hessian}.
void create_gradients();

Expand Down Expand Up @@ -296,6 +265,12 @@ class Fastcpd {
const colvec& theta
);

// Gradient routine for the fallback (non built-in) case: the constructor
// stores its address in the `get_gradient` member pointer, through which it
// is invoked with a segment range and a parameter vector.
// NOTE(review): presumably evaluates the user-supplied `cost_gradient` R
// function on the segment -- confirm against the definition.
colvec get_gradient_custom(
const unsigned int segment_start,
const unsigned int segment_end,
const colvec& theta
);

colvec get_gradient_lm(
const unsigned int segment_start,
const unsigned int segment_end,
Expand Down Expand Up @@ -326,6 +301,12 @@ class Fastcpd {
const colvec& theta
);

// Hessian routine for the fallback (non built-in) case: the constructor
// stores its address in the `get_hessian` member pointer, through which it
// is invoked with a segment range and a parameter vector.
// NOTE(review): presumably evaluates the user-supplied `cost_hessian` R
// function on the segment -- confirm against the definition.
mat get_hessian_custom(
const unsigned int segment_start,
const unsigned int segment_end,
const colvec& theta
);

mat get_hessian_lm(
const unsigned int segment_start,
const unsigned int segment_end,
Expand All @@ -352,6 +333,14 @@ class Fastcpd {
const Nullable<colvec>& start
);

// PELT cost routine used when `family` is not one of the built-in families:
// the constructor stores its address in the `get_nll_pelt` member pointer.
// NOTE(review): the meaning of `cv` and `start` for a custom cost is not
// visible from this declaration -- confirm against the definition.
CostResult get_nll_pelt_custom(
const unsigned int segment_start,
const unsigned int segment_end,
const double lambda,
const bool cv,
const Nullable<colvec>& start
);

CostResult get_nll_pelt_glm(
const unsigned int segment_start,
const unsigned int segment_end,
Expand Down Expand Up @@ -413,6 +402,13 @@ class Fastcpd {
double lambda
);

// Cost value at a given `theta` for the fallback (non built-in) case: the
// constructor stores its address in the `get_nll_sen` member pointer, which
// `get_cval_sen` calls for the "custom" family and which the line search in
// `update_cost_parameters_step` calls for each candidate `theta`.
// NOTE(review): presumably evaluates the user-supplied `cost` R function on
// the segment -- confirm against the definition.
double get_nll_sen_custom(
const unsigned int segment_start,
const unsigned int segment_end,
colvec theta,
double lambda
);

double get_nll_sen_lm(
const unsigned int segment_start,
const unsigned int segment_end,
Expand Down Expand Up @@ -445,9 +441,6 @@ class Fastcpd {
const unsigned int segment_end
);

// Get the value of \code{theta_sum}.
mat get_theta_sum();

void update_cost_parameters(
const unsigned int t,
const unsigned int tau,
Expand Down
Loading

0 comments on commit 0c57b2d

Please sign in to comment.