From bb30689ea06b998c75786873b5e91ec8e9f9f507 Mon Sep 17 00:00:00 2001 From: talegari Date: Tue, 18 Jun 2024 18:40:26 +0530 Subject: [PATCH] added prune/reorder --- .DS_Store | Bin 8196 -> 8196 bytes DESCRIPTION | 8 +- NAMESPACE | 26 + R/augment.R | 466 --------- R/package.R | 22 +- R/rulelist.R | 1751 +++++++++++++++++++++++++++++--- R/tidy.R | 40 +- R/utils.R | 16 +- man/as_rulelist.Rd | 11 +- man/as_rulelist.data.frame.Rd | 11 +- man/augment.Rd | 10 +- man/augment.rulelist.Rd | 119 +-- man/augment_class_keys.Rd | 4 +- man/augment_class_no_keys.Rd | 4 +- man/augment_regr_keys.Rd | 4 +- man/augment_regr_no_keys.Rd | 4 +- man/calculate.Rd | 22 + man/calculate.rulelist.Rd | 141 +++ man/plot.prune_rulelist.Rd | 19 + man/plot.rulelist.Rd | 40 + man/predict.rulelist.Rd | 13 +- man/print.prune_rulelist.Rd | 16 + man/print.rulelist.Rd | 11 +- man/prune.Rd | 22 + man/prune.rulelist.Rd | 84 ++ man/reorder.Rd | 21 + man/reorder.rulelist.Rd | 53 + man/rulelist.Rd | 32 +- man/set_keys.Rd | 22 +- man/set_validation_data.Rd | 58 ++ man/tidy.C5.0.Rd | 6 +- man/tidy.Rd | 6 +- man/tidy.constparty.Rd | 13 +- man/tidy.cubist.Rd | 7 +- man/tidy.rpart.Rd | 7 +- man/to_sql_case.Rd | 4 +- tests/testthat/test-rulelist.R | 358 ++++++- 37 files changed, 2620 insertions(+), 831 deletions(-) delete mode 100644 R/augment.R create mode 100644 man/calculate.Rd create mode 100644 man/calculate.rulelist.Rd create mode 100644 man/plot.prune_rulelist.Rd create mode 100644 man/plot.rulelist.Rd create mode 100644 man/print.prune_rulelist.Rd create mode 100644 man/prune.Rd create mode 100644 man/prune.rulelist.Rd create mode 100644 man/reorder.Rd create mode 100644 man/reorder.rulelist.Rd create mode 100644 man/set_validation_data.Rd diff --git a/.DS_Store b/.DS_Store index 99063433d20788fa3495cc6188c6f8eff9d6fbc9..c2579a648a3278c9e46a9c8c2c8ca4cc4a192c57 100644 GIT binary patch delta 121 zcmZp1XmOa}I9U^hRb(qtZiYEBl06o!0;WQO9&9|dI@*(dW0MRRd*#tTSPSDP3Z yPZkt(k%ucO3ogpb$ delta 51 
zcmV-30L=e{K!iY$PXQpYP`eKSAd?Ibdy{VwB$M_Ms= 1.0.3), cli (>= 3.6.2), glue (>= 1.7.0), + pheatmap (>= 1.0.12), + proxy (>= 0.4.27), Suggests: AmesHousing (>= 0.0.3), dplyr (>= 0.8), @@ -36,7 +38,7 @@ Suggests: knitr (>= 1.23), rmarkdown (>= 1.13), palmerpenguins (>= 0.1.1), -Description: Utility to convert text based summary of rule based models to a rulelist or ruleset dataframe (where each row represents a rule) with related metrics such as support, confidence and lift. Rule based models from these packages are supported: 'C5.0', 'rpart' and 'Cubist'. +Description: Extract rules as a rulelist (a class based on dataframe) along with metrics per rule such as support, confidence, lift, RMSE, IQR. Rulelists can be augmented using validation data, manipulated using standard dataframe operations, used to predict on unseen data, pruned based on some metrics and reordered to optimize for a metric. Utilities include manually creating rulesets, exporting a rulelist to SQL syntax and so on. 
URL: https://github.com/talegari/tidyrules BugReports: https://github.com/talegari/tidyrules/issues License: GPL-3 diff --git a/NAMESPACE b/NAMESPACE index fb12416..ae56e71 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,28 +2,48 @@ S3method(as_rulelist,data.frame) S3method(augment,rulelist) +S3method(calculate,rulelist) +S3method(plot,prune_rulelist) +S3method(plot,rulelist) S3method(predict,rulelist) +S3method(print,prune_rulelist) S3method(print,rulelist) +S3method(prune,rulelist) +S3method(reorder,rulelist) S3method(tidy,C5.0) S3method(tidy,constparty) S3method(tidy,cubist) S3method(tidy,rpart) export(as_rulelist) export(augment) +export(calculate) export(convert_rule_flavor) +export(prune) +export(reorder) export(set_keys) +export(set_validation_data) export(tidy) export(to_sql_case) importFrom(data.table,":=") importFrom(generics,augment) +importFrom(generics,calculate) +importFrom(generics,prune) importFrom(generics,tidy) +importFrom(graphics,abline) +importFrom(graphics,axis) +importFrom(graphics,legend) +importFrom(graphics,lines) importFrom(magrittr,"%>%") importFrom(rlang,"%||%") importFrom(stats,IQR) importFrom(stats,predict) +importFrom(stats,runif) importFrom(stats,weighted.mean) +importFrom(tidytable,across) importFrom(tidytable,all_of) importFrom(tidytable,arrange) +importFrom(tidytable,bind_cols) +importFrom(tidytable,bind_rows) importFrom(tidytable,distinct) importFrom(tidytable,drop_na) importFrom(tidytable,inner_join) @@ -31,10 +51,16 @@ importFrom(tidytable,left_join) importFrom(tidytable,mutate) importFrom(tidytable,n) importFrom(tidytable,nest) +importFrom(tidytable,pivot_longer) +importFrom(tidytable,pivot_wider) +importFrom(tidytable,pull) importFrom(tidytable,relocate) importFrom(tidytable,right_join) importFrom(tidytable,row_number) importFrom(tidytable,select) +importFrom(tidytable,slice) importFrom(tidytable,summarise) importFrom(tidytable,unnest) importFrom(utils,data) +importFrom(utils,head) +importFrom(utils,tail) diff --git 
a/R/augment.R b/R/augment.R deleted file mode 100644 index 81c5d1d..0000000 --- a/R/augment.R +++ /dev/null @@ -1,466 +0,0 @@ -################################################################################ -# This is the part of the 'tidyrules' R package hosted at -# https://github.com/talegari/tidyrules with GPL-3 license. -################################################################################ - -#' @keywords internal -#' @name augment_class_no_keys -#' @title as the name says -#' @description as the name says -#' not to be exported -augment_class_no_keys = function(x, new_data, y_name, weight = 1L, ...){ - - # raw predictions - pred_df = - x %>% - select(rule_nbr, LHS) %>% - predict(new_data, multiple = TRUE) %>% - unnest(rule_nbr) %>% - select(row_nbr, rule_nbr) - - # new_data with rule_nbr and 'keys' - new_data_with_rule_nbr = - new_data %>% - mutate(row_nbr = row_number()) %>% - mutate(weight__ = local(weight)) %>% - left_join(pred_df, by = "row_nbr") %>% - left_join(select(x, rule_nbr, RHS), by = "rule_nbr") - - prevalence_df = - new_data_with_rule_nbr %>% - summarise(prevalence_0 = sum(weight__, na.rm = TRUE), - .by = eval(y_name) - ) %>% - drop_na(prevalence_0) %>% - mutate(prevalence = prevalence_0 / sum(prevalence_0)) %>% - select(all_of(c(eval(y_name), "prevalence"))) - - aggregatees_df = - new_data_with_rule_nbr %>% - # bring 'prevalence' column - left_join(prevalence_df,by = eval(y_name)) %>% - summarise( - support = sum(weight__, na.rm = TRUE), - confidence = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), - weight__, - na.rm = TRUE - ), - lift = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), - weight__, - na.rm = TRUE - ) / prevalence[1], - .by = rule_nbr - ) %>% - nest(.by = rule_nbr, .key = "augmented_stats") - - # output has all columns of 'tidy' along with 'augment_stats' - res = - x %>% - left_join(aggregatees_df, by = c("rule_nbr")) %>% - arrange(rule_nbr) - - return(res) -} - -#' @keywords internal 
-#' @name augment_class_keys -#' @title as the name says -#' @description as the name says -#' not to be exported -augment_class_keys = function(x, new_data, y_name, weight = 1L, ...){ - - keys = attr(x, "keys") - - # raw predictions - # columns: row_nbr, rule_nbr, `keys` - pred_df = - x %>% - select(all_of(c("rule_nbr", "LHS", keys))) %>% - predict(new_data, multi = TRUE) %>% - unnest(rule_nbr) # columns: row_nbr, rule_nbr, `keys` - - # new_data with rule_nbr and 'keys' - # columns: row_nbr, rule_nbr, `keys`, RHS, columns of new_data - new_data_with_rule_nbr = - # new_data with row_nbr and weight__ columns - new_data %>% - mutate(row_nbr = row_number()) %>% - mutate(weight__ = weight) %>% - # bring rule_nbr, `keys` (multiple rows per row_nbr might get created) - inner_join(pred_df, by = "row_nbr") %>% - # bring RHS column from tidy object - inner_join(select(x, all_of(c("rule_nbr", keys, "RHS"))), - by = c(keys, "rule_nbr") - ) - - # prevalence per 'keys' - prevalence_df = - new_data_with_rule_nbr %>% - summarise(prevalence_0 = sum(weight__, na.rm = TRUE), - .by = c(keys, eval(y_name)) - ) %>% - drop_na(prevalence_0) %>% - mutate(prevalence = prevalence_0 / sum(prevalence_0, na.rm = TRUE), - .by = c(keys) - ) %>% - select(all_of(c(keys, eval(y_name), "prevalence"))) - - # add aggregates at rule_nbr and 'keys' level - aggregatees_df = - new_data_with_rule_nbr %>% - left_join(prevalence_df, by = c(keys, eval(y_name))) %>% - summarise( - support = sum(weight__, na.rm = TRUE), - confidence = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), - weight__, - na.rm = TRUE - ), - lift = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), - weight__, - na.rm = TRUE - ) / prevalence[1], - ..., - .by = c(keys, "rule_nbr") - ) %>% - nest(.by = c("rule_nbr", keys), .key = "augmented_stats") - - # output has all columns of 'tidy' along with 'augment_stats' - res = - x %>% - left_join(aggregatees_df, by = c("rule_nbr", keys)) %>% - 
arrange(!!!rlang::syms(c(keys, "rule_nbr"))) %>% - relocate(all_of(c("rule_nbr", keys))) - - return(res) -} - -#' @keywords internal -#' @name augment_regr_no_keys -#' @title as the name says -#' @description as the name says -#' not to be exported -augment_regr_no_keys = function(x, new_data, y_name, weight = 1L, ...){ - - # raw predictions - pred_df = - x %>% - select(rule_nbr, LHS) %>% - predict(new_data, multiple = TRUE) %>% - unnest(rule_nbr) %>% - select(row_nbr, rule_nbr) - - # new_data with rule_nbr and 'keys' - new_data_with_rule_nbr = - new_data %>% - mutate(row_nbr = row_number()) %>% - mutate(weight__ = local(weight)) %>% - left_join(pred_df, by = "row_nbr") %>% - left_join(select(x, rule_nbr, RHS), by = "rule_nbr") - - if (is.character(x$RHS)) { - new_data_with_rule_nbr = - new_data_with_rule_nbr %>% - nest(.by = c("RHS", "row_nbr")) %>% - mutate(RHS = purrr::map2_dbl(RHS, - data, - ~ eval(parse(text = .x), envir = .y) - ) - ) %>% - unnest(data) - } - - aggregatees_df = - new_data_with_rule_nbr %>% - summarise( - support = sum(weight__, na.rm = TRUE), - IQR = DescTools::IQRw(.data[[y_name]], weight__, na.rm = TRUE), - RMSE = MetricsWeighted::rmse(actual = .data[[y_name]], - predicted = RHS, - w = weight__, - na.rm = TRUE - ), - .by = rule_nbr - ) %>% - nest(.by = rule_nbr, .key = "augmented_stats") - - # output has all columns of 'tidy' along with 'augment_stats' - res = - x %>% - left_join(aggregatees_df, by = c("rule_nbr")) %>% - arrange(rule_nbr) - - return(res) -} - -#' @keywords internal -#' @name augment_regr_keys -#' @title as the name says -#' @description as the name says -#' not to be exported -augment_regr_keys = function(x, new_data, y_name, weight = 1L, ...){ - - keys = attr(x, "keys") - - # raw predictions - # columns: row_nbr, rule_nbr, `keys` - pred_df = - x %>% - select(all_of(c("rule_nbr", "LHS", keys))) %>% - predict(new_data, multi = TRUE) %>% - unnest(rule_nbr) # columns: row_nbr, rule_nbr, `keys` - - # new_data with rule_nbr and 
'keys' - # columns: row_nbr, rule_nbr, `keys`, RHS, columns of new_data - new_data_with_rule_nbr = - # new_data with row_nbr and weight__ columns - new_data %>% - mutate(row_nbr = row_number()) %>% - mutate(weight__ = weight) %>% - # bring rule_nbr, `keys` (multiple rows per row_nbr might get created) - inner_join(pred_df, by = "row_nbr") %>% - # bring RHS column from tidy object - inner_join(select(x, all_of(c("rule_nbr", keys, "RHS"))), - by = c(keys, "rule_nbr") - ) - - if (is.character(x$RHS)) { - new_data_with_rule_nbr = - new_data_with_rule_nbr %>% - nest(.by = c("RHS", keys, "row_nbr")) %>% - mutate(RHS = purrr::map2_dbl(RHS, - data, - ~ eval(parse(text = .x), envir = .y) - ) - ) %>% - unnest(data) - } - - aggregatees_df = - new_data_with_rule_nbr %>% - summarise( - support = sum(weight__, na.rm = TRUE), - IQR = DescTools::IQRw(.data[[y_name]], weight__, na.rm = TRUE), - RMSE = MetricsWeighted::rmse(actual = .data[[y_name]], - predicted = RHS, - w = weight__, - na.rm = TRUE - ), - .by = c(keys, "rule_nbr") - ) %>% - nest(.by = c("rule_nbr", keys), .key = "augmented_stats") - - # output has all columns of 'tidy' along with 'augment_stats' - res = - x %>% - left_join(aggregatees_df, by = c("rule_nbr", keys)) %>% - arrange(!!!rlang::syms(c(keys, "rule_nbr"))) %>% - relocate(all_of(c("rule_nbr", keys))) - - return(res) -} - -#' @name augment -#' @title `augment` is re-export of [generics::augment] from -#' [tidyrules][package_tidyrules] package -#' @description See [augment.rulelist] -#' @param x A [rulelist] -#' @param ... For methods to use -#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist] -#' @importFrom generics augment -#' @family Augment -#' @export -generics::augment - -#' @name augment.rulelist -#' @title Augment a [rulelist] -#' @description `augment` outputs a [rulelist] with an additional column named -#' `augmented_stats` based on summary statistics calculated using `new_data`. 
-#' @param x A [rulelist] -#' @param new_data (dataframe) with column named `y_name` present -#' @param y_name (string) Column name representing the dependent variable -#' @param weight (numeric, default: 1) Positive weight vector with length equal -#' to one or number of rows of 'new_data' -#' @param ... (expressions) To be send to [tidytable::summarise] for custom -#' aggregations. See examples. -#' @returns A [rulelist] with a new dataframe-column named `augmented_stats`. -#' @details The dataframe-column `augmented_stats` will have these columns -#' corresponding to the `estimation_type`: -#' -#' - For `regression`: `support`, `IQR`, `RMSE` -#' - For `classification`: `support`, `confidence`, `lift` -#' -#' All these metrics are computed in a weighted sense. Arg `weight` is 1 by -#' default. -#' -#' @examples -#' # Examples for augment ------------------------------------------------------ -#' library("magrittr") -#' -#' # C5 ---- -#' att = modeldata::attrition -#' set.seed(100) -#' train_index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE) -#' -#' model_c5 = C50::C5.0(Attrition ~., data = att[train_index, ], rules = TRUE) -#' tidy_c5 = tidy(model_c5) -#' tidy_c5 -#' -#' # augment -#' augmented = augment(tidy_c5, new_data = att[!train_index, ], y_name = "Attrition") -#' -#' augmented %>% -#' tidytable::unnest(augmented_stats, names_sep = "__") %>% -#' tidytable::glimpse() -#' -#' # augment with custom aggregator -#' augmented = -#' augment(tidy_c5, -#' new_data = att[!train_index, ], -#' y_name = "Attrition", -#' output_counts = list(table(Attrition)) -#' ) -#' -#' augmented %>% -#' tidytable::unnest(augmented_stats, names_sep = "__") %>% -#' tidytable::glimpse() -#' -#' # rpart ---- -#' set.seed(100) -#' train_index = sample(c(TRUE, FALSE), nrow(iris), replace = TRUE) -#' -#' model_class_rpart = rpart::rpart(Species ~ ., data = iris[train_index, ]) -#' tidy_class_rpart = tidy(model_class_rpart) -#' tidy_class_rpart -#' -#' model_regr_rpart = 
rpart::rpart(Sepal.Length ~ ., data = iris[train_index, ]) -#' tidy_regr_rpart = tidy(model_regr_rpart) -#' tidy_regr_rpart -#' -#' #' augment (classification case) -#' augmented = -#' augment(tidy_class_rpart, -#' new_data = iris[!train_index, ], -#' y_name = "Species" -#' ) -#' augmented -#' -#' augmented %>% -#' tidytable::unnest(augmented_stats, names_sep = "__") %>% -#' tidytable::glimpse() -#' -#' #' augment (regression case) -#' augmented = -#' augment(tidy_regr_rpart, -#' new_data = iris[!train_index, ], -#' y_name = "Sepal.Length" -#' ) -#' augmented -#' -#' augmented %>% -#' tidytable::unnest(augmented_stats, names_sep = "__") %>% -#' tidytable::glimpse() -#' -#' # party ---- -#' pen = palmerpenguins::penguins -#' set.seed(100) -#' train_index = sample(c(TRUE, FALSE), nrow(pen), replace = TRUE) -#' -#' model_class_party = partykit::ctree(species ~ ., data = pen[train_index, ]) -#' tidy_class_party = tidy(model_class_party) -#' tidy_class_party -#' -#' model_regr_party = partykit::ctree(bill_length_mm ~ ., data = pen[train_index, ]) -#' tidy_regr_party = tidy(model_regr_party) -#' tidy_regr_party -#' -#' #' augment (classification case) -#' augmented = -#' augment(tidy_class_party, -#' new_data = pen[!train_index, ], -#' y_name = "species" -#' ) -#' augmented -#' -#' augmented %>% -#' tidytable::unnest(augmented_stats, names_sep = "__") %>% -#' tidytable::glimpse() -#' -#' #' augment (regression case) -#' augmented = -#' augment(tidy_regr_party, -#' new_data = tidytable::drop_na(pen[!train_index, ], bill_length_mm), -#' y_name = "bill_length_mm" -#' ) -#' augmented -#' -#' augmented %>% -#' tidytable::unnest(augmented_stats, names_sep = "__") %>% -#' tidytable::glimpse() -#' -#' # cubist ---- -#' att = modeldata::attrition -#' set.seed(100) -#' train_index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE) -#' cols_att = setdiff(colnames(att), c("MonthlyIncome", "Attrition")) -#' -#' model_cubist = Cubist::cubist(x = att[train_index, cols_att], -#' y = 
att[train_index, "MonthlyIncome"] -#' ) -#' -#' tidy_cubist = tidy(model_cubist) -#' tidy_cubist -#' -#' augmented = -#' augment(tidy_cubist, -#' new_data = att[!train_index, ], -#' y_name = "MonthlyIncome" -#' ) -#' augmented -#' -#' augmented %>% -#' tidytable::unnest(augmented_stats, names_sep = "__") %>% -#' tidytable::glimpse() -#' -#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist] -#' @family Augment -#' @export -augment.rulelist = function(x, new_data, y_name, weight = 1L, ...){ - - checkmate::assert_string(y_name) - checkmate::assert_data_frame(new_data) - checkmate::assert_true(y_name %in% colnames(new_data)) - checkmate::assert_numeric(weight, - lower = 1e-8, - finite = TRUE, - any.missing = FALSE - ) - checkmate::assert_true(length(weight) %in% c(1, nrow(new_data))) - checkmate::assert_false(anyNA(new_data[[y_name]])) - - - estimation_type = attr(x, "estimation_type") - keys = attr(x, "keys") - - if (is.null(keys)) { - if (estimation_type == "classification"){ - res = augment_class_no_keys(x, new_data, y_name, weight, ...) - } else if (estimation_type == "regression") { - res = augment_regr_no_keys(x, new_data, y_name, weight, ...) - } else { - rlang::abort("unknown 'estimation_type'") - } - - } else { - - if (estimation_type == "classification"){ - res = augment_class_keys(x, new_data, y_name, weight, ...) - } else if (estimation_type == "regression") { - res = augment_regr_keys(x, new_data, y_name, weight, ...) 
- } else { - rlang::abort("unknown 'estimation_type'") - } - } - - attr(res, "data") = new_data - return(res) -} diff --git a/R/package.R b/R/package.R index 229845f..0f2060c 100644 --- a/R/package.R +++ b/R/package.R @@ -36,6 +36,21 @@ #' @importFrom tidytable row_number #' @importFrom tidytable drop_na #' @importFrom tidytable relocate +#' @importFrom tidytable bind_rows +#' @importFrom tidytable pull +#' @importFrom tidytable slice +#' @importFrom tidytable pivot_wider +#' @importFrom tidytable pivot_longer +#' @importFrom tidytable bind_cols +#' @importFrom tidytable across +#' @importFrom graphics abline +#' @importFrom graphics axis +#' @importFrom graphics legend +#' @importFrom graphics lines +#' @importFrom stats runif +#' @importFrom utils head +#' @importFrom utils tail +#' "_PACKAGE" list.rules.party = getFromNamespace(".list.rules.party", "partykit") @@ -72,7 +87,12 @@ utils::globalVariables(c(".", ".data", "rn_df", "trial_nbr", - "error" + "error", + "data__", + "rn_df__", + "hit", + "priority", + "value" ) ) diff --git a/R/rulelist.R b/R/rulelist.R index e5a14e0..1fb3f63 100644 --- a/R/rulelist.R +++ b/R/rulelist.R @@ -1,16 +1,13 @@ -################################################################################ +#******************************************************************************* # This is the part of the 'tidyrules' R package hosted at # https://github.com/talegari/tidyrules with GPL-3 license. 
-################################################################################ +#******************************************************************************* -################################################################################ -#### rulelist documentation -################################################################################ +#### rulelist documentation ---- #' @name rulelist #' @title Rulelist -#' @description -#' ## Structure +#' @description ## Structure #' #' A `rulelist` is ordered list of rules stored as a dataframe. Each row, #' specifies a rule (LHS), expected outcome (RHS) and some other details. @@ -32,13 +29,13 @@ #' | 4|( island %in% c('Dream', 'Torgersen') ) & ( bill_length_mm <= 44.1 ) |Adelie | 111| 0.9459459| 2.140825| #' ``` #' -#' ## Create a rulelist +#' ## Create a rulelist #' #' A `rulelist` can be created using [tidy()] on some supported model fits #' (run: `utils::methods(tidy)`). It can also be created manually from a #' existing dataframe using [as_rulelist][as_rulelist.data.frame]. #' -#' ## Keys and attributes +#' ## Keys and attributes #' #' Columns identified as 'keys' along with `rule_nbr` form a unique #' combination @@ -54,16 +51,32 @@ #' - `keys`: (character vector)Names of the column that forms a key. #' - `model_type`: (string) Name of the model #' -#' ## Methods for rulelist +#' ## Set Validation data +#' +#' This helps a few methods like [augment], [calculate], [prune], [reorder] +#' require few additional attributes which can be set using +#' [set_validation_data]. +#' +#' ## Methods for rulelist #' #' 1. [Predict][predict.rulelist]: Given a dataframe (possibly without a #' dependent variable column aka 'test data'), predicts the first rule (as #' ordered in the rulelist) per 'keys' that is applicable for each row. When #' `multiple = TRUE`, returns all rules applicable for a row (per key). #' -#' 2. 
[Augment][augment.rulelist]: Given a dataframe (with dependent variable -#' column, aka validation data), creates summary statistics per rule and -#' returns a rulelist with a new dataframe-column. +#' 2. [Augment][augment.rulelist]: Outputs summary statistics per rule over +#' validation data and returns a rulelist with a new dataframe-column. +#' +#' 3. [Calculate][calculate.rulelist]: Computes metrics for a rulelist in a +#' cumulative manner such as `cumulative_coverage`, `cumulative_overlap`, +#' `cumulative_accuracy`. +#' +#' 4. [Prune][prune.rulelist]: Suggests pruning a rulelist such that some +#' expectation are met (based on metrics). Example: cumulative_coverage of 80% +#' can be met with a first few rules. +#' +#' 5. [Reorder][reorder.rulelist]: Reorders a rulelist in order to maximize a +#' metric. #' #' ## Manipulating a rulelist #' @@ -75,19 +88,272 @@ #' ## Utilities for a rulelist #' #' 1. [as_rulelist][as_rulelist.data.frame]: Create a `rulelist` from a -#' dataframe with some mandatory columns. 2. [set_keys]: Set or Unset 'keys' -#' of a `rulelist`. 3. [to_sql_case]: Outputs a SQL case statement for a -#' `rulelist`. 4. [convert_rule_flavor]: Converts `R`-parsable rule strings to -#' python/SQL parsable rule strings. +#' dataframe with some mandatory columns. +#' +#' 2. [set_keys]: Set or Unset 'keys' of a `rulelist`. +#' +#' 3. [to_sql_case]: Outputs a SQL case statement for a `rulelist`. +#' +#' 4. [convert_rule_flavor]: Converts `R`-parsable rule strings to python/SQL +#' parsable rule strings. 
#' #' @seealso [rulelist], [tidy], [augment][augment.rulelist], -#' [predict][predict.rulelist] +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] identity # just a placeholder for 'rulelist' documentation, not exported -################################################################################ -#### print -################################################################################ + +#### validate ---- + +#' @keywords internal +#' raises a meaningful error if there is a problem +#' else returns TRUE invisibly +#' not to be exported +validate_rulelist = function(x){ + + checkmate::assert_class(x, "rulelist") + + keys = attr(x, "keys") + estimation_type = attr(x, "estimation_type") + model_type = attr(x, "model_type") + validation_data = attr(x, "validation_data") + y_name = attr(x, "y_name") + weight = attr(x, "weight") + + x = as.data.frame(x) + + # check on basic columns and 'key' columns + basic_cols = c("rule_nbr", "LHS", "RHS") + if (is.null(keys)) { + checkmate::assert_subset(basic_cols, colnames(x)) + # create key combo + key_combo_df = distinct(x, rule_nbr) + + } else { + + # keys should be different from basic cols + if (length(intersect(keys, basic_cols)) > 0) { + rlang::abort("keys should not one among: 'rule_nbr', 'LHS', 'RHS'") + } + # expected columns exist exist + checkmate::assert_subset(c(basic_cols, keys), colnames(x)) + # create key combo + key_combo_df = distinct(select(x, all_of(c("rule_nbr", keys)))) + } + + if (nrow(key_combo_df) != nrow(x)) { + rlang::abort("`rule_nbr` and 'keys' (if any) together are not unique") + } + checkmate::assert_false(anyNA(key_combo_df)) + + checkmate::assert_vector(x$rule_nbr, any.missing = FALSE) + checkmate::assert_character(x$LHS, any.missing = FALSE) + checkmate::assert_vector(x$RHS, any.missing = FALSE) + + checkmate::assert_string(model_type, null.ok = TRUE) + + checkmate::assert_string(estimation_type) + 
checkmate::assert_subset(estimation_type, c("classification", "regression")) + + if (!is.null(validation_data)) { + + checkmate::assert_data_frame(validation_data) + + checkmate::assert_string(y_name) + checkmate::assert_true(y_name %in% colnames(validation_data)) + + if (estimation_type == "classification") { + checkmate::assert_factor(validation_data[[y_name]], any.missing = FALSE) + } else if (estimation_type == "regression") { + checkmate::assert_numeric(validation_data[[y_name]], any.missing = FALSE) + } + + checkmate::assert_numeric(weight, + lower = 0, + finite = TRUE, + any.missing = FALSE + ) + checkmate::assert_true(length(weight) %in% c(1, nrow(validation_data))) + } + + return(invisible(TRUE)) + +} + +#### as_rulelist ---- + +#' @name as_rulelist +#' @title as_rulelist generic from [tidyrules][package_tidyrules] package +#' @description as_rulelist generic +#' @param x object to be coerced to a [rulelist] +#' @param ... for methods to use +#' @return A [rulelist] +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @export +as_rulelist = function(x, ...){ + UseMethod("as_rulelist", x) +} + +#' @name as_rulelist.data.frame +#' @title as_rulelist method for a data.frame +#' @description Convert a set of rules in a dataframe to a [rulelist] +#' @param x dataframe to be coerced to a [rulelist] +#' @param keys (character vector, default: NULL) column names which form the key +#' @param model_type (string, default: NULL) Name of the model which generated +#' the rules +#' @param estimation_type (string) One among: 'regression', +#' 'classification' +#' @param ... currently unused +#' @return [rulelist] object +#' @details Input dataframe should contain these columns: `rule_nbr`, `LHS`, +#' `RHS`. Providing other inputs helps augment better. 
+#' @examples +#' rules_df = tidytable::tidytable(rule_nbr = 1:2, +#' LHS = c("var_1 > 50", "var_2 < 30"), +#' RHS = c(2, 1) +#' ) +#' as_rulelist(rules_df, estimation_type = "regression") +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @export +as_rulelist.data.frame = function(x, + keys = NULL, + model_type = NULL, + estimation_type, + ... + ){ + + # set class and attributes + res = rlang::duplicate(x) + + class(res) = c("rulelist", class(res)) + attr(res, "keys") = keys + attr(res, "estimation_type") = estimation_type + attr(res, "model_type") = model_type + + # validate rulelist + validate_rulelist(res) + + return(res) +} + +#### set_keys ---- + +#' @name set_keys +#' @title Set keys for a [rulelist] +#' @description 'keys' are a set of column(s) which identify a group of rules in +#' a [rulelist]. Methods like [predict][predict.rulelist], +#' [augment][augment.rulelist] produce output per key combination. +#' +#' @param x A [rulelist] +#' @param keys (character vector or NULL) +#' @param reset (flag) Whether to reset `rule_nbr` to sequential integers starting +#' with 1 when `keys` is set to NULL +#' +#' @returns A [rulelist] object +#' +#' @details A new [rulelist] is returned with attr `keys` modified. The input +#' [rulelist] object is unaltered. +#' +#' @examples +#' model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) +#' tidy_c5 = tidy(model_c5) +#' tidy_c5 # keys are: "trial_nbr" +#' +#' tidy_c5[["rule_nbr"]] = 1:nrow(tidy_c5) +#' new_tidy_c5 = set_keys(tidy_c5, NULL) # remove all keys +#' new_tidy_c5 +#' +#' new_2_tidy_c5 = set_keys(new_tidy_c5, "trial_nbr") # set "trial_nbr" as key +#' new_2_tidy_c5 +#' +#' # Note that `tidy_c5` and `new_tidy_c5` are not altered. 
+#' tidy_c5 +#' new_tidy_c5 +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @family Core Rulelist Utility +#' @export +set_keys = function(x, keys, reset = FALSE){ + + validate_rulelist(x) + + res = rlang::duplicate(x) + attr(res, "keys") = keys + + if (is.null(keys) && reset){ + res[["rule_nbr"]] = seq_len(nrow(res)) # seq_len() stays empty for a zero-row rulelist, unlike 1:nrow() + cli::cli_alert_info("`set_keys` has reset `rule_nbr` column to sequential integers starting with 1.") + } + + validate_rulelist(res) + + return(res) +} + +#### set_validation_data ---- + +#' @name set_validation_data +#' @title Add `validation_data` to a [rulelist] +#' @description Returns a [rulelist] with three new attributes set: +#' `validation_data`, `y_name` and `weight`. Methods such as +#' [augment][augment.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] require this to be set. +#' +#' @param x A [rulelist] +#' @param validation_data (dataframe) Data to be used for computing some metrics. +#' It is expected to contain `y_name` column. +#' @param y_name (string) Name of the dependent variable column. +#' @param weight (non-negative numeric vector, default: 1) Weight per +#' observation/row of `validation_data`. This is expected to have same length +#' as the number of rows in `validation_data`. Only exception is when it is a +#' single positive number, which means that all rows have equal weight. +#' +#' @returns A [rulelist] with some extra attributes set. 
+#' +#' @examples +#' att = modeldata::attrition +#' set.seed(100) +#' index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE) +#' model_c5 = C50::C5.0(Attrition ~., data = att[index, ], rules = TRUE) +#' +#' tidy_c5 = tidy(model_c5) +#' tidy_c5 +#' +#' tidy_c5_2 = set_validation_data(tidy_c5, +#' validation_data = att[!index, ], +#' y_name = "Attrition", +#' weight = 1 # default +#' ) +#' tidy_c5_2 +#' tidy_c5 # not altered +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @family Core Rulelist Utility +#' @export +set_validation_data = function(x, validation_data, y_name, weight = 1){ + + res = rlang::duplicate(x) + + checkmate::assert_data_frame(validation_data) + attr(res, "validation_data") = data.table::as.data.table(validation_data) + attr(res, "y_name") = y_name + attr(res, "weight") = weight + + validate_rulelist(res) # validate the copy so the newly set validation_data / y_name / weight are checked + + return(res) +} + +#### print ---- #' @name print.rulelist #' @title Print method for [rulelist] class @@ -95,16 +361,24 @@ identity # just a placeholder for 'rulelist' documentation, not exported #' @param x A [rulelist] object #' @param ... 
Passed to `tidytable::print` #' @return input [rulelist] (invisibly) -#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist] -#' @family Core Rulelist Utility +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @export print.rulelist = function(x, ...){ - keys = attr(x, "keys") + + validate_rulelist(x) + + keys = attr(x, "keys") + estimation_type = attr(x, "estimation_type") + model_type = attr(x, "model_type") + validation_data = attr(x, "validation_data") cli::cli_rule(left = "Rulelist") + cli::cli_text("") if (is.null(keys)) { - cli::cli_alert_info("{.emph keys}: {.strong NULL}") + cli::cli_alert_info("{.emph Keys}: {.strong NULL}") } else { cli::cli_alert_info("{.emph keys}: {.val {keys}}") n_combo = nrow(distinct(select(x, all_of(keys)))) @@ -113,31 +387,134 @@ print.rulelist = function(x, ...){ cli::cli_alert_info("{.emph Number of rules}: {.val {nrow(x)}}") - model_type = attr(x, 'model_type') if (is.null(model_type)){ cli::cli_alert_info("{.emph Model type}: {.strong NULL}") } else { cli::cli_alert_info("{.emph Model type}: {.val {model_type}}") } - estimation_type = attr(x, 'estimation_type') if (is.null(estimation_type)){ - cli::cli_alert_info("{.emph estimation type}: {.strong NULL}") + cli::cli_alert_info("{.emph Estimation type}: {.strong NULL}") + } else { + cli::cli_alert_info("{.emph Estimation type}: {.val {estimation_type}}") + } + + if (is.null(validation_data)){ + cli::cli_alert_warning("{.emph Is validation data set}: {.strong FALSE}") } else { - cli::cli_alert_info("{.emph estimation type}: {.val {estimation_type}}") + cli::cli_alert_success("{.emph Is validation data set}: {.strong TRUE}") } + cli::cli_text("") class(x) = setdiff(class(x), "rulelist") print(x, ...) 
+ cli::cli_rule() class(x) = c("rulelist", class(x)) return(invisible(x)) } -################################################################################ -#### predict -################################################################################ +#### plot ---- + +#' @name plot.rulelist +#' @title Plot method for rulelist +#' @description Plots a heatmap with `rule_nbr`'s on x-side and clusters of +#' `row_nbr`'s on y-side of a binary matrix with 1 if a rule is applicable for +#' a row. +#' +#' @param x A [rulelist] +#' @param thres_cluster_rows (positive integer) Maximum number of rows beyond +#' which a x-side dendrogram is not computed +#' @param dist_metric (string or function, default: "jaccard") Distance metric +#' for y-side (`rule_nbr`) passed to `method` argument of [proxy::dist] +#' @param ... Arguments to be passed to [pheatmap::pheatmap] +#' +#' @details Number of clusters is set to min(number of unique rows in the +#' row_nbr X rule_nbr matrix and thres_cluster_rows) +#' +#' @examples +#' library("magrittr") +#' att = modeldata::attrition +#' tidy_c5 = +#' C50::C5.0(Attrition ~., data = att, rules = TRUE) %>% +#' tidy() %>% +#' set_validation_data(att, "Attrition") %>% +#' set_keys(NULL) +#' +#' plot(tidy_c5) +#' +#' @export +plot.rulelist = function(x, + thres_cluster_rows = 1e3, + dist_metric = "jaccard", + ... 
+ ){ + + validate_rulelist(x) + + checkmate::assert_false(is.null(attr(x, "validation_data"))) + checkmate::assert_true(is.null(attr(x, "keys"))) + + validation_data = attr(x, "validation_data") + y_name = attr(x, "y_name") + estimation_type = attr(x, "estimation_type") + + df_plot = + x %>% + predict(validation_data, multiple = TRUE) %>% + unnest(rule_nbr) %>% + drop_na(rule_nbr) %>% + mutate(value = 1L) %>% + left_join(validation_data %>% + select(all_of(y_name)) %>% + mutate(row_nbr = row_number()), + by = "row_nbr" + ) %>% + mutate(rule_nbr = as.character(rule_nbr)) %>% + tidytable::pivot_wider(names_from = rule_nbr, + values_from = value, + values_fill = 0L + ) %>% + arrange(row_nbr) + + mat_obj = + df_plot %>% + select(-all_of(c(y_name, "row_nbr"))) %>% + as.matrix() + + rownames(mat_obj) = 1:nrow(mat_obj) + + row_df = data.frame("y_name" = df_plot[[y_name]]) + colnames(row_df) = y_name + rownames(row_df) = rownames(mat_obj) + + if (estimation_type == "regression" && is.character(x$RHS)) { + col_df = NA + } else { + col_df = select(as.data.frame(x), rule_nbr, RHS) + colnames(col_df) = c("rule_nbr", y_name) + rownames(col_df) = col_df[["rule_nbr"]] + col_df[["rule_nbr"]] = NULL + } + n_unique_rows = nrow(unique(mat_obj)) + n_clusters = min(n_unique_rows, thres_cluster_rows) + + pheatmap::pheatmap( + mat_obj, + kmeans_k = min(n_unique_rows, n_clusters), + clustering_distance_cols = proxy::dist(mat_obj, + method = dist_metric, + by_rows = FALSE + ), + annotation_row = row_df, + annotation_col = col_df, + ... 
+ ) +} + + +#### predict ---- #' @keywords internal #' @name predict_all_nokeys_rulelist @@ -153,6 +530,7 @@ predict_all_nokeys_rulelist = function(rulelist, new_data){ new_data2 = rlang::duplicate(new_data) new_data2[["row_nbr"]] = 1:nrow(new_data2) + # loop over rules and stotre covered rows out = vector("list", nrow(rulelist)) for (rn in 1:nrow(rulelist)) { @@ -161,12 +539,21 @@ predict_all_nokeys_rulelist = function(rulelist, new_data){ out[[rn]] = new_data2$row_nbr[mask] } - res = - tidytable::tidytable(rule_nbr = 1:nrow(rulelist), - row_nbr = out - ) %>% - unnest(row_nbr, keep_empty = TRUE) %>% - tidytable::full_join(tidytable::tidytable(row_nbr = 1:nrow(new_data))) + # unnest row_nbr (list column of integers) + res = tidytable::tidytable(rule_nbr = rulelist$rule_nbr, + row_nbr = out + ) + + all_nested_row_nbrs_are_null = all(purrr::map_lgl(res$row_nbr, rlang::is_null)) + if (all_nested_row_nbrs_are_null) { + res[["row_nbr"]] = NA_integer_ + } else { + res = unnest(res, row_nbr) + } + + res = tidytable::full_join(res, + tidytable::tidytable(row_nbr = 1:nrow(new_data)) + ) return(res) } @@ -182,7 +569,7 @@ predict_all_nokeys_rulelist = function(rulelist, new_data){ predict_all_rulelist = function(rulelist, new_data){ new_data = data.table::as.data.table(new_data) - keys = attr(rulelist, "keys", exact = TRUE) + keys = attr(rulelist, "keys") if (is.null(keys)) { @@ -198,10 +585,14 @@ predict_all_rulelist = function(rulelist, new_data){ res = rulelist %>% as.data.frame() %>% - nest(data = tidytable::everything(), .by = keys) %>% - mutate(rn_df = purrr::map(data, ~ predict_all_nokeys_rulelist(.x, new_data))) %>% - select(-data) %>% - unnest(rn_df) %>% + nest(data__ = tidytable::everything(), .by = keys) %>% + mutate(rn_df__ = + purrr::map(data__, + ~ predict_all_nokeys_rulelist(.x, new_data) + ) + ) %>% + select(-data__) %>% + unnest(rn_df__) %>% drop_na(row_nbr) %>% select(all_of(c("row_nbr", keys, "rule_nbr"))) %>% arrange(!!!rlang::syms(c("row_nbr", keys, 
"rule_nbr"))) %>% @@ -226,6 +617,7 @@ predict_nokeys_rulelist = function(rulelist, new_data){ new_data2 = rlang::duplicate(new_data) new_data2[["row_nbr"]] = 1:nrow(new_data2) + # loop through rules and keep removing covered rows from new_data2 out = vector("list", nrow(rulelist)) for (rn in 1:nrow(rulelist)) { @@ -242,12 +634,21 @@ predict_nokeys_rulelist = function(rulelist, new_data){ } } - res = - tidytable::tidytable(rule_nbr = 1:nrow(rulelist), - row_nbr = out - ) %>% - unnest(row_nbr, keep_empty = TRUE) %>% - tidytable::full_join(tidytable::tidytable(row_nbr = 1:nrow(new_data))) + # unnest row_nbr (list column of integers) + res = tidytable::tidytable(rule_nbr = rulelist$rule_nbr, + row_nbr = out + ) + + all_nested_row_nbrs_are_null = all(purrr::map_lgl(res$row_nbr, rlang::is_null)) + if (all_nested_row_nbrs_are_null) { + res[["row_nbr"]] = NA_integer_ + } else { + res = unnest(res, row_nbr) + } + + res = tidytable::full_join(res, + tidytable::tidytable(row_nbr = 1:nrow(new_data)) + ) return(res) } @@ -263,7 +664,7 @@ predict_nokeys_rulelist = function(rulelist, new_data){ predict_rulelist = function(rulelist, new_data){ new_data = data.table::as.data.table(new_data) - keys = attr(rulelist, "keys", exact = TRUE) + keys = attr(rulelist, "keys") if (is.null(keys)) { @@ -277,10 +678,12 @@ predict_rulelist = function(rulelist, new_data){ res = rulelist %>% as.data.frame() %>% - nest(data = tidytable::everything(), .by = keys) %>% - mutate(rn_df = purrr::map(data, ~ predict_nokeys_rulelist(.x, new_data))) %>% - select(-data) %>% - unnest(rn_df) %>% + nest(data__ = tidytable::everything(), .by = keys) %>% + mutate(rn_df__ = + purrr::map(data__, ~ predict_nokeys_rulelist(.x, new_data)) + ) %>% + select(-data__) %>% + unnest(rn_df__) %>% drop_na(row_nbr) %>% select(all_of(c("row_nbr", keys, "rule_nbr"))) %>% arrange(!!!rlang::syms(c("row_nbr", keys, "rule_nbr"))) @@ -293,16 +696,21 @@ predict_rulelist = function(rulelist, new_data){ #' @title `predict` method for a 
[rulelist] #' @description Predicts `rule_nbr` applicable (as per the order in rulelist) #' for a `row_nbr` (per key) in new_data +#' #' @param object A [rulelist] #' @param new_data (dataframe) #' @param multiple (flag, default: FALSE) Whether to output all rule numbers #' applicable for a row. If FALSE, the first satisfying rule is provided. #' @param ... unused -#' @return dataframe. See **Details**. +#' +#' @returns A dataframe. See **Details**. +#' #' @details If a `row_nbr` is covered more than one `rule_nbr` per 'keys', then #' `rule_nbr` appearing earlier (as in row order of the [rulelist]) takes #' precedence. -#' ## Output Format +#' +#' ## Output Format +#' #' - When multiple is `FALSE`(default), output is a dataframe with three #' or more columns: `row_number` (int), columns corresponding to 'keys', #' `rule_nbr` (int). @@ -327,12 +735,17 @@ predict_rulelist = function(rulelist, new_data){ #' #' output_2 = predict(tidy_c5, palmerpenguins::penguins, multiple = TRUE) #' output_2 # `rule_nbr` is a list-column of integer vectors -#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist] +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @importFrom stats predict #' @family Core Rulelist Utility #' @export +#' predict.rulelist = function(object, new_data, multiple = FALSE, ...){ + validate_rulelist(object) checkmate::assert_data_frame(new_data) checkmate::assert_flag(multiple) @@ -345,148 +758,1150 @@ predict.rulelist = function(object, new_data, multiple = FALSE, ...){ return(res) } -################################################################################ -#### coerce from dataframe -################################################################################ +#### augment ---- -#' @name as_rulelist -#' @title as_rulelist generic from [tidyrules][package_tidyrules] package -#' 
@description as_rulelist generic -#' @param x object to be coerced to a [rulelist] -#' @param ... for methods to use -#' @return A [rulelist] -#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist] -#' @family Core Rulelist Utility -#' @export -as_rulelist = function(x, ...){ - UseMethod("as_rulelist", x) +#' @keywords internal +#' @name augment_class_no_keys +#' @title as the name says +#' @description as the name says +#' not to be exported +augment_class_no_keys = function(x, new_data, y_name, weight, ...){ + + # raw predictions + pred_df = + predict(x, new_data, multiple = TRUE) %>% + unnest(rule_nbr) %>% + select(row_nbr, rule_nbr) + + # new_data with rule_nbr and 'keys' + new_data_with_rule_nbr = + new_data %>% + mutate(row_nbr = row_number()) %>% + mutate(weight__ = local(weight)) %>% + left_join(pred_df, by = "row_nbr") %>% + left_join(select(x, rule_nbr, RHS), by = "rule_nbr") + + prevalence_df = + new_data_with_rule_nbr %>% + summarise(prevalence_0 = sum(weight__, na.rm = TRUE), + .by = eval(y_name) + ) %>% + drop_na(prevalence_0) %>% + mutate(prevalence = prevalence_0 / sum(prevalence_0)) %>% + select(all_of(c(eval(y_name), "prevalence"))) + + aggregatees_df = + new_data_with_rule_nbr %>% + # bring 'prevalence' column + left_join(prevalence_df,by = eval(y_name)) %>% + summarise( + support = sum(weight__, na.rm = TRUE), + confidence = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), + weight__, + na.rm = TRUE + ), + lift = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), + weight__, + na.rm = TRUE + ) / prevalence[1], + .by = rule_nbr + ) %>% + nest(.by = rule_nbr, .key = "augmented_stats") + + # output has all columns of 'tidy' along with 'augment_stats' + res = + x %>% + left_join(aggregatees_df, by = c("rule_nbr")) %>% + arrange(rule_nbr) + + return(res) } -#' @name as_rulelist.data.frame -#' @title as_rulelist method for a data.frame -#' @description Convert a set of rules in a dataframe to 
a [rulelist] -#' @param x dataframe to be coerced to a [rulelist] -#' @param keys (character vector, default: NULL) column names which form the key -#' @param model_type (string, default: NULL) Name of the model which generated -#' the rules -#' @param estimation_type (string) One among: 'regression', -#' 'classification' -#' @param ... currently unused -#' @return [rulelist] object -#' @details Input dataframe should contain these columns: `rule_nbr`, `LHS`, -#' `RHS`. Providing other inputs helps augment better. +#' @keywords internal +#' @name augment_class_keys +#' @title as the name says +#' @description as the name says +#' not to be exported +augment_class_keys = function(x, new_data, y_name, weight, ...){ + + keys = attr(x, "keys") + + # raw predictions + # columns: row_nbr, rule_nbr, `keys` + pred_df = + predict(x, new_data, multi = TRUE) %>% + unnest(rule_nbr) # columns: row_nbr, rule_nbr, `keys` + + # new_data with rule_nbr and 'keys' + # columns: row_nbr, rule_nbr, `keys`, RHS, columns of new_data + new_data_with_rule_nbr = + # new_data with row_nbr and weight__ columns + new_data %>% + mutate(row_nbr = row_number()) %>% + mutate(weight__ = weight) %>% + # bring rule_nbr, `keys` (multiple rows per row_nbr might get created) + inner_join(pred_df, by = "row_nbr") %>% + # bring RHS column from tidy object + inner_join(select(x, all_of(c("rule_nbr", keys, "RHS"))), + by = c(keys, "rule_nbr") + ) + + # prevalence per 'keys' + prevalence_df = + new_data_with_rule_nbr %>% + summarise(prevalence_0 = sum(weight__, na.rm = TRUE), + .by = c(keys, eval(y_name)) + ) %>% + drop_na(prevalence_0) %>% + mutate(prevalence = prevalence_0 / sum(prevalence_0, na.rm = TRUE), + .by = c(keys) + ) %>% + select(all_of(c(keys, eval(y_name), "prevalence"))) + + # add aggregates at rule_nbr and 'keys' level + aggregatees_df = + new_data_with_rule_nbr %>% + left_join(prevalence_df, by = c(keys, eval(y_name))) %>% + summarise( + support = sum(weight__, na.rm = TRUE), + confidence = 
weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), + weight__, + na.rm = TRUE + ), + lift = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE), + weight__, + na.rm = TRUE + ) / prevalence[1], + ..., + .by = c(keys, "rule_nbr") + ) %>% + nest(.by = c("rule_nbr", keys), .key = "augmented_stats") + + # output has all columns of 'tidy' along with 'augment_stats' + res = + x %>% + left_join(aggregatees_df, by = c("rule_nbr", keys)) %>% + arrange(!!!rlang::syms(c(keys, "rule_nbr"))) %>% + relocate(all_of(c("rule_nbr", keys))) + + return(res) +} + +#' @keywords internal +#' @name augment_regr_no_keys +#' @title as the name says +#' @description as the name says +#' not to be exported +augment_regr_no_keys = function(x, new_data, y_name, weight, ...){ + + # raw predictions + pred_df = + predict(x, new_data, multiple = TRUE) %>% + unnest(rule_nbr) %>% + select(row_nbr, rule_nbr) + + # new_data with rule_nbr and 'keys' + new_data_with_rule_nbr = + new_data %>% + mutate(row_nbr = row_number()) %>% + mutate(weight__ = local(weight)) %>% + left_join(pred_df, by = "row_nbr") %>% + left_join(select(x, rule_nbr, RHS), by = "rule_nbr") + + if (is.character(x$RHS)) { + new_data_with_rule_nbr = + new_data_with_rule_nbr %>% + nest(.by = c("RHS", "row_nbr")) %>% + mutate(RHS = purrr::map2_dbl(RHS, + data, + ~ eval(parse(text = .x), envir = .y) + ) + ) %>% + unnest(data) + } + + aggregatees_df = + new_data_with_rule_nbr %>% + summarise( + support = sum(weight__, na.rm = TRUE), + IQR = DescTools::IQRw(.data[[y_name]], weight__, na.rm = TRUE), + RMSE = MetricsWeighted::rmse(actual = .data[[y_name]], + predicted = RHS, + w = weight__, + na.rm = TRUE + ), + .by = rule_nbr + ) %>% + nest(.by = rule_nbr, .key = "augmented_stats") + + # output has all columns of 'tidy' along with 'augment_stats' + res = + x %>% + left_join(aggregatees_df, by = c("rule_nbr")) %>% + arrange(rule_nbr) + + return(res) +} + +#' @keywords internal +#' @name augment_regr_keys +#' @title as the 
name says +#' @description as the name says +#' not to be exported +augment_regr_keys = function(x, new_data, y_name, weight, ...){ + + keys = attr(x, "keys") + + # raw predictions + # columns: row_nbr, rule_nbr, `keys` + pred_df = + predict(x, new_data, multi = TRUE) %>% + unnest(rule_nbr) # columns: row_nbr, rule_nbr, `keys` + + # new_data with rule_nbr and 'keys' + # columns: row_nbr, rule_nbr, `keys`, RHS, columns of new_data + new_data_with_rule_nbr = + # new_data with row_nbr and weight__ columns + new_data %>% + mutate(row_nbr = row_number()) %>% + mutate(weight__ = weight) %>% + # bring rule_nbr, `keys` (multiple rows per row_nbr might get created) + inner_join(pred_df, by = "row_nbr") %>% + # bring RHS column from tidy object + inner_join(select(x, all_of(c("rule_nbr", keys, "RHS"))), + by = c(keys, "rule_nbr") + ) + + if (is.character(x$RHS)) { + new_data_with_rule_nbr = + new_data_with_rule_nbr %>% + nest(.by = c("RHS", keys, "row_nbr")) %>% + mutate(RHS = purrr::map2_dbl(RHS, + data, + ~ eval(parse(text = .x), envir = .y) + ) + ) %>% + unnest(data) + } + + aggregatees_df = + new_data_with_rule_nbr %>% + summarise( + support = sum(weight__, na.rm = TRUE), + IQR = DescTools::IQRw(.data[[y_name]], weight__, na.rm = TRUE), + RMSE = MetricsWeighted::rmse(actual = .data[[y_name]], + predicted = RHS, + w = weight__, + na.rm = TRUE + ), + .by = c(keys, "rule_nbr") + ) %>% + nest(.by = c("rule_nbr", keys), .key = "augmented_stats") + + # output has all columns of 'tidy' along with 'augment_stats' + res = + x %>% + left_join(aggregatees_df, by = c("rule_nbr", keys)) %>% + arrange(!!!rlang::syms(c(keys, "rule_nbr"))) %>% + relocate(all_of(c("rule_nbr", keys))) + + return(res) +} + +#' @name augment +#' @title `augment` is re-export of [generics::augment] from +#' [tidyrules][package_tidyrules] package +#' @description See [augment.rulelist] +#' +#' @param x A [rulelist] +#' @param ... 
For methods to use +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @importFrom generics augment +#' @export +generics::augment + +#' @name augment.rulelist +#' @title Augment a [rulelist] +#' @description `augment` outputs a [rulelist] with an additional column named +#' `augmented_stats` based on summary statistics calculated using attribute +#' `validation_data`. +#' @param x A [rulelist] +#' @param ... (expressions) To be send to [tidytable::summarise] for custom +#' aggregations. See examples. +#' @returns A [rulelist] with a new dataframe-column named `augmented_stats`. +#' @details The dataframe-column `augmented_stats` will have these columns +#' corresponding to the `estimation_type`: +#' +#' - For `regression`: `support`, `IQR`, `RMSE` +#' - For `classification`: `support`, `confidence`, `lift` +#' +#' along with custom aggregations. +#' +#' #' @examples -#' rules_df = tidytable::tidytable(rule_nbr = 1:2, -#' LHS = c("var_1 > 50", "var_2 < 30"), -#' RHS = c(2, 1) -#' ) -#' as_rulelist(rules_df, estimation_type = "regression") -#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist] -#' @family Core Rulelist Utility +#' # Examples for augment ------------------------------------------------------ +#' library("magrittr") +#' +#' # C5 ---- +#' att = modeldata::attrition +#' set.seed(100) +#' train_index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE) +#' +#' model_c5 = C50::C5.0(Attrition ~., data = att[train_index, ], rules = TRUE) +#' tidy_c5 = +#' model_c5 %>% +#' tidy() %>% +#' set_validation_data(att[!train_index, ], "Attrition") +#' +#' tidy_c5 +#' +#' augment(tidy_c5) %>% +#' tidytable::unnest(augmented_stats, names_sep = "__") %>% +#' tidytable::glimpse() +#' +#' # augment with custom aggregator +#' augment(tidy_c5,output_counts = list(table(Attrition))) %>% +#' 
tidytable::unnest(augmented_stats, names_sep = "__") %>% +#' tidytable::glimpse() +#' +#' # rpart ---- +#' set.seed(100) +#' train_index = sample(c(TRUE, FALSE), nrow(iris), replace = TRUE) +#' +#' model_class_rpart = rpart::rpart(Species ~ ., data = iris[train_index, ]) +#' tidy_class_rpart = tidy(model_class_rpart) %>% +#' set_validation_data(iris[!train_index, ], "Species") +#' tidy_class_rpart +#' +#' model_regr_rpart = rpart::rpart(Sepal.Length ~ ., data = iris[train_index, ]) +#' tidy_regr_rpart = tidy(model_regr_rpart) %>% +#' set_validation_data(iris[!train_index, ], "Sepal.Length") +#' tidy_regr_rpart +#' +#' # augment (classification case) +#' augment(tidy_class_rpart) %>% +#' tidytable::unnest(augmented_stats, names_sep = "__") %>% +#' tidytable::glimpse() +#' +#' # augment (regression case) +#' augment(tidy_regr_rpart) %>% +#' tidytable::unnest(augmented_stats, names_sep = "__") %>% +#' tidytable::glimpse() +#' +#' # party ---- +#' pen = palmerpenguins::penguins %>% +#' tidytable::drop_na(bill_length_mm) +#' set.seed(100) +#' train_index = sample(c(TRUE, FALSE), nrow(pen), replace = TRUE) +#' +#' model_class_party = partykit::ctree(species ~ ., data = pen[train_index, ]) +#' tidy_class_party = tidy(model_class_party) %>% +#' set_validation_data(pen[!train_index, ], "species") +#' tidy_class_party +#' +#' model_regr_party = +#' partykit::ctree(bill_length_mm ~ ., data = pen[train_index, ]) +#' tidy_regr_party = tidy(model_regr_party) %>% +#' set_validation_data(pen[!train_index, ], "bill_length_mm") +#' tidy_regr_party +#' +#' # augment (classification case) +#' augment(tidy_class_party) %>% +#' tidytable::unnest(augmented_stats, names_sep = "__") %>% +#' tidytable::glimpse() +#' +#' # augment (regression case) +#' augment(tidy_regr_party) %>% +#' tidytable::unnest(augmented_stats, names_sep = "__") %>% +#' tidytable::glimpse() +#' +#' # cubist ---- +#' att = modeldata::attrition +#' set.seed(100) +#' train_index = sample(c(TRUE, FALSE), nrow(att), 
replace = TRUE) +#' cols_att = setdiff(colnames(att), c("MonthlyIncome", "Attrition")) +#' +#' model_cubist = Cubist::cubist(x = att[train_index, cols_att], +#' y = att[train_index, "MonthlyIncome"] +#' ) +#' +#' tidy_cubist = tidy(model_cubist) %>% +#' set_validation_data(att[!train_index, ], "MonthlyIncome") +#' tidy_cubist +#' +#' augment(tidy_cubist) %>% +#' tidytable::unnest(augmented_stats, names_sep = "__") %>% +#' tidytable::glimpse() +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @export -as_rulelist.data.frame = function(x, - keys = NULL, - model_type = NULL, - estimation_type, - ... - ){ - checkmate::assert_character(keys, - min.len = 1, - any.missing = FALSE, - unique = TRUE, - null.ok = TRUE - ) +augment.rulelist = function(x, ...){ - # #### checks - # 1. basic cols exist. - # 2. keys are different from basic cols. - # 3. key columns exist. - # 4. key along with 'rule_nbr' form unique rows without missing values. - # 5. rule_nbr (integerish), LHS(character), RHS(any vector) should not have - # missing values. - # 6. 'estimation_type' should be one among: classification, regression + validate_rulelist(x) + estimation_type = attr(x, "estimation_type") + keys = attr(x, "keys") - # check on basic columns and 'key' columns - basic_cols = c("rule_nbr", "LHS", "RHS") if (is.null(keys)) { - checkmate::assert_subset(basic_cols, colnames(x)) - # create key combo - key_combo_df = distinct(x, rule_nbr) + if (estimation_type == "classification"){ + res = augment_class_no_keys(x, + attr(x, "validation_data"), + attr(x, "y_name"), + attr(x, "weight"), + ... + ) + } else if (estimation_type == "regression") { + res = augment_regr_no_keys(x, + attr(x, "validation_data"), + attr(x, "y_name"), + attr(x, "weight"), + ... 
+ ) + } } else { - # keys should be different from basic cols - if (length(intersect(keys, basic_cols)) > 0) { - rlang::abort("keys should not one among: 'rule_nbr', 'LHS', 'RHS'") + if (estimation_type == "classification"){ + res = augment_class_keys(x, + attr(x, "validation_data"), + attr(x, "y_name"), + attr(x, "weight"), + ... + ) + } else if (estimation_type == "regression") { + res = augment_regr_keys(x, + attr(x, "validation_data"), + attr(x, "y_name"), + attr(x, "weight"), + ... + ) } - # expected columns exist exist - checkmate::assert_subset(c(basic_cols, keys), colnames(x)) - # create key combo - key_combo_df = distinct(select(x, all_of(c("rule_nbr", keys)))) } - checkmate::assert_true(anyDuplicated(key_combo_df) == 0) - checkmate::assert_false(anyNA(key_combo_df)) + return(res) +} - checkmate::assert_integerish(x$rule_nbr, any.missing = FALSE) - checkmate::assert_character(x$LHS, any.missing = FALSE) - checkmate::assert_vector(x$RHS, any.missing = FALSE) +#### metrics ---- - checkmate::assert_string(model_type, null.ok = TRUE) - checkmate::assert_string(estimation_type) - checkmate::assert_subset(estimation_type, c("classification", "regression")) +#' @keywords internal +metric__cumulative_coverage = function(rulelist, new_data, y_name, weight){ - # set class and attributes - res = rlang::duplicate(x) + weight_df = tidytable::tidytable(row_nbr = 1:nrow(new_data), weight = weight) - class(res) = c("rulelist", class(res)) - if (!is.null(model_type)) { - attr(res, "model_type") = model_type + # loop over rules and get coverage + predicted = predict(rulelist, new_data) + in_union = integer(0) + cum_weighted_coverage = numeric(nrow(rulelist)) + + for (i in 1:nrow(rulelist)) { + row_nbrs = + predict(rulelist[i, ], new_data) %>% + drop_na(rule_nbr) %>% + pull(row_nbr) + + in_union = union(in_union, row_nbrs) + cum_weighted_coverage[i] = sum(weight_df[row_nbr %in% in_union, ][["weight"]]) } - attr(res, "estimation_type") = estimation_type + 
return(cum_weighted_coverage) +} + +#' @keywords internal +metric__cumulative_accuracy = function(rulelist, new_data, y_name, weight){ + + priority_df = + rulelist %>% + select(rule_nbr) %>% + mutate(priority = 1:nrow(rulelist)) %>% + select(rule_nbr, priority) + pred_df = + predict(rulelist, new_data) %>% + mutate(weight = local(weight)) %>% + left_join(priority_df, by = "rule_nbr") %>% + select(rule_nbr, row_nbr, weight, priority) + + new_data2 = + new_data %>% + mutate(row_nbr = 1:n()) %>% + select(all_of(c("row_nbr", y_name))) + + confidence_till_rule = function(rn){ + + pred_df %>% + tidytable::filter(priority <= rn) %>% + left_join(new_data2, by = "row_nbr") %>% + left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") %>% + mutate(hit = as.integer(RHS == .data[[y_name]])) %>% + summarise(conf = weighted.mean(hit, weight, na.rm = TRUE)) %>% + `[[`("conf") + } + + res = purrr::map_dbl(1:nrow(rulelist), confidence_till_rule) return(res) } -################################################################################ -#### set_keys -################################################################################ +#' @keywords internal +metric__cumulative_RMSE = function(rulelist, new_data, y_name, weight){ -#' @name set_keys -#' @title Set keys for a [rulelist] -#' @description 'keys' are a set of column(s) whose unique combination -#' identifies a group of rules in a [rulelist]. Methods like -#' [predict.rulelist], [augment.rulelist] produce output per key combination. 
+ priority_df = + rulelist %>% + select(rule_nbr) %>% + mutate(priority = 1:nrow(rulelist)) %>% + select(rule_nbr, priority) + + pred_df = + predict(rulelist, new_data) %>% + left_join(priority_df, by = "rule_nbr") %>% + mutate(weight = local(weight)) %>% + select(rule_nbr, row_nbr, weight, priority) + + new_data2 = + new_data %>% + mutate(row_nbr = 1:n()) %>% + select(all_of(c("row_nbr", y_name))) + + rmse_till_rule = function(rn){ + + if (is.character(rulelist$RHS)) { + inter_df = + pred_df %>% + tidytable::filter(priority <= rn) %>% + left_join(mutate(new_data, row_nbr = 1:n()), by = "row_nbr") %>% + left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") %>% + nest(.by = c("RHS", "rule_nbr", "row_nbr", "priority", "weight")) %>% + mutate(RHS = purrr::map2_dbl(RHS, + data, + ~ eval(parse(text = .x), envir = .y) + ) + ) %>% + unnest(data) + } else { + + inter_df = + pred_df %>% + tidytable::filter(priority <= rn) %>% + left_join(new_data2, by = "row_nbr") %>% + left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") + } + + inter_df %>% + summarise(rmse = MetricsWeighted::rmse(RHS, + .data[[y_name]], + weight, + na.rm = TRUE + ) + ) %>% + `[[`("rmse") + } + + res = purrr::map_dbl(1:nrow(rulelist), rmse_till_rule) + return(res) +} + +#' @keywords internal +metric__cumulative_overlap = function(rulelist, new_data, y_name, weight){ + weight_df = tidytable::tidytable(row_nbr = 1:nrow(new_data), weight = weight) + + # loop over rules and get coverage + predicted = predict(rulelist, new_data) + in_union = integer(0) + in_overlap = integer(0) + cum_weighted_overlap = numeric(nrow(rulelist)) + + for (i in 1:nrow(rulelist)) { + row_nbrs = + predict(rulelist[i, ], new_data) %>% + drop_na(rule_nbr) %>% + pull(row_nbr) + + in_overlap = union(in_overlap, intersect(in_union, row_nbrs)) + in_union = union(in_union, row_nbrs) + cum_weighted_overlap[i] = sum(weight_df[row_nbr %in% in_overlap, ][["weight"]]) + } + + return(cum_weighted_overlap) +} + +#### calculate ---- + 
+#' @name calculate +#' @title `calculate` is re-export of [generics::calculate] from +#' [tidyrules][package_tidyrules] package +#' @description See [calculate.rulelist] #' @param x A [rulelist] -#' @param keys (character vector or NULL) -#' @return A [rulelist] object -#' @details A new [rulelist] is returned with attr `keys` is modified. The input -#' [rulelist] object is unaltered. +#' @param ... See [calculate.rulelist] +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @importFrom generics calculate +#' @export +#' +generics::calculate + + +#' @name calculate.rulelist +#' @title `calculate` metrics for a [rulelist] +#' @description Computes some metrics (based on `estimation_type`) in cumulative +#' window function style over the rulelist (in the same order) ignoring the +#' keys. +#' +#' @details ## Default Metrics +#' These metrics are calculated by default: +#' +#' - `cumulative_coverage`: For nth rule in the rulelist, number of distinct `row_nbr`s (of `new_data`) covered by nth and all preceding rules (in order). In weighted case, we sum the weights corresponding to the distinct `row_nbr`s. +#' +#' - `cumulative_overlap`: Up til nth rule in the rulelist, number of distinct `row_nbr`s (of `new_data`) already covered by some preceding rule (in order). In weighted case, we sum the weights corresponding to the distinct `row_nbr`s. +#' +#' For classification: +#' +#' - `cumulative_accuracy`: For nth rule in the rulelist, fraction of `row_nbr`s such that `RHS` matches the `y_name` column (of `new_data`) by nth and all preceding rules (in order). In weighted case, weighted accuracy is computed. +#' +#' For regression: +#' +#' - `cumulative_RMSE`: For nth rule in the rulelist, weighted RMSE of all predictions (`RHS`) predicted by nth rule and all preceding rules. 
+#' +#' ## Custom metrics +#' +#' Custom metrics to be computed should be passed a named list of function(s) in +#' `...`. The custom metric function should take these arguments in same order: +#' `rulelist`, `new_data`, `y_name`, `weight`. The custom metric function should +#' return a numeric vector of same length as the number of rows of rulelist. +#' +#' @param x A [rulelist] +#' @param metrics_to_exclude (character vector) Names of metrics to exclude +#' +#' @param ... Named list of custom metrics. See 'details'. +#' +#' @returns A dataframe of metrics with a `rule_nbr` column. +#' #' @examples -#' model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) -#' tidy_c5 = tidy(model_c5) -#' tidy_c5 # keys are: "trial_nbr" +#' library("magrittr") +#' model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) +#' tidy_c5 = tidy(model_c5) %>% +#' set_validation_data(modeldata::attrition, "Attrition") %>% +#' set_keys(NULL) #' -#' new_tidy_c5 = set_keys(tidy_c5, NULL) # remove all keys -#' new_tidy_c5 -#' new_2_tidy_c5 = set_keys(new_tidy_c5, "trial_nbr") # set "trial_nbr" as key -#' new_2_tidy_c5 +#' # calculate default metrics (classification) +#' calculate(tidy_c5) #' -#' # Note that `tidy_c5` and `new_tidy_c5` are not altered. 
-#' tidy_c5 -#' new_tidy_c5 -#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist] -#' @family Core Rulelist Utility +#' model_rpart = rpart::rpart(MonthlyIncome ~., data = modeldata::attrition) +#' tidy_rpart = +#' tidy(model_rpart) %>% +#' set_validation_data(modeldata::attrition, "MonthlyIncome") %>% +#' set_keys(NULL) +#' +#' # calculate default metrics (regression) +#' calculate(tidy_rpart) +#' +#' # calculate default metrics with a custom metric +#' #' custom function to get cumulative MAE +#' library("tidytable") +#' get_cumulative_MAE = function(rulelist, new_data, y_name, weight){ +#' +#' priority_df = +#' rulelist %>% +#' select(rule_nbr) %>% +#' mutate(priority = 1:nrow(rulelist)) %>% +#' select(rule_nbr, priority) +#' +#' pred_df = +#' predict(rulelist, new_data) %>% +#' left_join(priority_df, by = "rule_nbr") %>% +#' mutate(weight = local(weight)) %>% +#' select(rule_nbr, row_nbr, weight, priority) +#' +#' new_data2 = +#' new_data %>% +#' mutate(row_nbr = 1:n()) %>% +#' select(all_of(c("row_nbr", y_name))) +#' +#' rmse_till_rule = function(rn){ +#' +#' if (is.character(rulelist$RHS)) { +#' inter_df = +#' pred_df %>% +#' tidytable::filter(priority <= rn) %>% +#' left_join(mutate(new_data, row_nbr = 1:n()), by = "row_nbr") %>% +#' left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") %>% +#' nest(.by = c("RHS", "rule_nbr", "row_nbr", "priority", "weight")) %>% +#' mutate(RHS = purrr::map2_dbl(RHS, +#' data, +#' ~ eval(parse(text = .x), envir = .y) +#' ) +#' ) %>% +#' unnest(data) +#' } else { +#' +#' inter_df = +#' pred_df %>% +#' tidytable::filter(priority <= rn) %>% +#' left_join(new_data2, by = "row_nbr") %>% +#' left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") +#' } +#' +#' inter_df %>% +#' summarise(rmse = MetricsWeighted::mae(RHS, +#' .data[[y_name]], +#' weight, +#' na.rm = TRUE +#' ) +#' ) %>% +#' `[[`("rmse") +#' } +#' +#' res = purrr::map_dbl(1:nrow(rulelist), rmse_till_rule) +#' return(res) 
+#' } +#' +#' calculate(tidy_rpart, +#' metrics_to_exclude = NULL, +#' list("cumulative_mae" = get_cumulative_MAE) +#' ) +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @export -set_keys = function(x, keys){ +calculate.rulelist = function(x, + metrics_to_exclude = NULL, + ... + ){ - checkmate::assert_character(keys, null.ok = TRUE) - if (!is.null(keys)){ - checkmate::assert_subset(keys, colnames(x)) - checkmate::assert_false(any(c("LHS", "RHS", "row_nbr") %in% keys)) + # checks + validate_rulelist(x) + + if (is.null(attr(x, "validation_data"))) { + cli::cli_alert_danger("validation_data is not present! Set using `set_validation_data`") + rlang::abort() } - res = rlang::duplicate(x) - attr(res, "keys") = keys + checkmate::assert_character(metrics_to_exclude, null.ok = TRUE) + + # ignore keys + keys = attr(x, "keys") + if (!is.null(keys)) { + cli::cli_alert_warning("'keys' will be ignored in `calculate`") + if (inherits(try(set_keys(x, NULL), silent = TRUE), "try-error")) { + x = set_keys(x, NULL, reset = TRUE) + } else { + x = set_keys(x, NULL) + } + } + + # set metric names to compute + metric_names = c("cumulative_coverage", "cumulative_overlap") + metric_names = switch( + attr(x, "estimation_type"), + classification = c(metric_names, "cumulative_accuracy"), + regression = c(metric_names, "cumulative_RMSE") + ) + metric_names = setdiff(metric_names, metrics_to_exclude) + + # compute metrics + res = tidytable::tidytable(rule_nbr = x$rule_nbr) + + for (a_metric_name in metric_names) { + metric_func = get(paste("metric", a_metric_name, sep = "__")) + out = metric_func(x, + attr(x, "validation_data"), + attr(x, "y_name"), + attr(x, "weight") + ) + res[[a_metric_name]] = out + } + + # compute udf metrics + extra_metrics = list(...) 
+ if (length(extra_metrics)){ + extra_metrics = list(...)[[1]] + } + + if (length(extra_metrics) > 0) { + # should be a named list of functions + checkmate::assert_list(extra_metrics, + any.missing = FALSE, + types = "function" + ) + checkmate::assert_names(names(extra_metrics), type = "unique") + checkmate::assert(!any(names(extra_metrics) %in% metric_names)) + + for (a_metric_name in names(extra_metrics)) { + metric_func = extra_metrics[[a_metric_name]] + out = metric_func(x, + attr(x, "validation_data"), + attr(x, "y_name"), + attr(x, "weight") + ) + res[[a_metric_name]] = out + } + } + return(res) } + + +#### prune ---- + +#' @name prune +#' @title `prune` is re-export of [generics::prune] from +#' [tidyrules][package_tidyrules] package +#' @description See [prune.rulelist] +#' @param tree A [rulelist] +#' @param ... See [prune.rulelist] +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @importFrom generics prune +#' @export +generics::prune + +#' @name prune.rulelist +#' @title `prune` rules of a [rulelist] +#' @description Prune the [rulelist] by suggesting to keep first 'k' rules based +#' on metrics computed by [calculate][calculate.rulelist] +#' +#' @param tree A [rulelist] +#' @param metrics_to_exclude (character vector or NULL) Names of metrics not to +#' be calculated. See [calculate][calculate.rulelist] for the list of default +#' metrics. +#' @param stop_expr_string (string default: "relative__cumulative_coverage >= +#' 0.9") Parsable condition +#' @param min_n_rules (positive integer) Minimum number of rules to keep +#' @param ... Named list of custom metrics passed to +#' [calculate][calculate.rulelist] +#' +#' @details 1. Metrics are computed using [calculate][calculate.rulelist]. 2. +#' Relative metrics (prepended by 'relative__') are calculated by dividing +#' each metric by its max value. 3.
The first rule in rulelist order which +#' meets the 'stop_expr_string' criteria is stored (say 'pos'). Print method +#' suggests to keep rules until pos. +#' +#' @returns Object of class 'prune_rulelist' with these components: 1. pruned: +#' ruleset keeping only first 'pos' rows. 2. n_pruned_rules: pos. If stop +#' criteria is never met, then pos = nrow(ruleset) 3. n_total_rules: +#' nrow(ruleset), 4. metrics_df: Dataframe with metrics and relative metrics +#' 5. stop_expr_string +#' +#' @examples +#' library("magrittr") +#' model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) +#' tidy_c5 = tidy(model_c5) %>% +#' set_validation_data(modeldata::attrition, "Attrition") %>% +#' set_keys(NULL) +#' +#' #' prune with defaults +#' prune_obj = prune(tidy_c5) +#' #' note that all other metrics are visible in the print output +#' prune_obj +#' plot(prune_obj) +#' prune_obj$pruned +#' +#' #' prune with a different stop_expr_string threshold +#' prune_obj = prune(tidy_c5, +#' stop_expr_string = "relative__cumulative_coverage >= 0.2" +#' ) +#' prune_obj #' as expected, has fewer than 10 rules as compared to default args +#' plot(prune_obj) +#' prune_obj$pruned +#' +#' #' prune with a different stop_expr_string metric +#' st = "relative__cumulative_overlap <= 0.7 & relative__cumulative_overlap > 0" +#' prune_obj = prune(tidy_c5, stop_expr_string = st) +#' prune_obj #' as expected, has fewer than 10 rules as compared to default args +#' plot(prune_obj) +#' prune_obj$pruned +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @export +prune.rulelist = function( + tree, + metrics_to_exclude = NULL, + stop_expr_string = "relative__cumulative_coverage >= 0.9", + min_n_rules = 1, + ... + ){ + x = tree + validate_rulelist(x) + + # get metrics + metrics_df = calculate(x, metrics_to_exclude, ...)
+ + # create "relative__" metrics with minmax scaling + relative_metrics_df = + metrics_df %>% + mutate(across(-rule_nbr, + list(relative = ~ .x / max(.x, na.rm = TRUE)), + .names = "{.fn}__{.col}" + ) + ) + + mask = eval(parse(text = stop_expr_string), envir = relative_metrics_df) + pos = purrr::detect_index(mask, isTRUE) + if (pos == 0){ pos = nrow(x)} + pos = max(min_n_rules, pos) + + res = list(pruned = head(x, pos), + n_pruned_rules = pos, + n_total_rules = nrow(x), + metrics_df = relative_metrics_df, + stop_expr_string = stop_expr_string + ) + class(res) = c("prune_rulelist", class(res)) + return(res) +} + +#' @name print.prune_rulelist +#' @title Print method for `prune_rulelist` class +#' @description Print method for `prune_rulelist` class +#' @param x A 'prune_rulelist' object +#' @param ... unused +#' @export +print.prune_rulelist = function(x, ...) { + + cli::cli_rule("Prune Suggestion") + cli::cli_text() + if (x$n_pruned_rules < x$n_total_rules) { + cli::cli_alert_success( + glue::glue("Keep first {x$n_pruned_rules} out of {x$n_total_rules}")) + } else { + cli::cli_alert_danger( + glue::glue("Stop criteria is not met. Pruning is not possible.")) + } + + cli::cli_text() + cli::cli_alert_info(glue::glue("Metrics after {x$n_pruned_rules} rules: ")) + + x$metrics_df %>% + slice(x$n_pruned_rules) %>% + tidytable::pivot_longer(-rule_nbr, names_to = "metric", values_to = "value") %>% + print() + + cli::cli_text() + cli::cli_alert_info("Run `plot(x)` for details; `x$pruned` to get pruned rulelist") + cli::cli_rule() + + return(invisible(x)) +} + +#' @name plot.prune_rulelist +#' @title Plot method for `prune_rulelist` class +#' @description Plot method for `prune_rulelist` class +#' @param x A 'prune_rulelist' object +#' @param ... unused +#' @returns Called for its side effect of drawing a base graphics plot +#' @export +plot.prune_rulelist = function(x, ...)
{ + + data_for_plot = + x$metrics_df %>% + mutate(rule_nbr = factor(as.character(rule_nbr), + levels = as.character(rule_nbr) + ) + ) %>% + select(rule_nbr, tidytable::starts_with("relative__")) %>% + pivot_longer(-rule_nbr, names_to = "metric", values_to = "value") + + n_total_rules = x[["n_total_rules"]] + n_pruned_rules = x[["n_pruned_rules"]] + + plot(x = 1:n_total_rules, + y = runif(n_total_rules), + type = "n", # No points initially + xaxt = "n", # Don't draw x-axis yet + yaxt = "n", + xlab = "rule_nbr", + ylab = "value" + ) + + for (metric_level in unique(data_for_plot$metric)) { + + subset = data_for_plot[data_for_plot$metric == metric_level,] + # Plot lines with color based on metric + lines(subset$rule_nbr, + subset$value, + col = which(unique(data_for_plot$metric) == metric_level), + type = "l" + ) + } + axis(1, at = 1:n_total_rules, labels = unique(data_for_plot$rule_nbr), las = 2) + axis(2, at = NULL, ylim = c(-0.1, 1.1)) + abline(v = n_pruned_rules, col = "darkgreen", lty = 3) + + legend("bottomright", + legend = unique(data_for_plot$metric), + col = 1:length(unique(data_for_plot$metric)), + pch = 19 + ) +} + +#### reorder ---- + +#' @name reorder +#' @title reorder generic +#' @description reorder generic for rulelist +#' @param x A [rulelist] +#' @param ... See [reorder.rulelist] +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @export +reorder = function(x, ...){ + UseMethod("reorder", x) +} + +#' @name reorder.rulelist +#' @title Reorder the rules/rows of a [rulelist] +#' @description Implements a greedy strategy to add one rule at a time which +#' maximizes/minimizes a metric. +#' +#' @param x A [rulelist] +#' @param metric (character vector or named list) Name of metrics or a custom +#' function(s). See [calculate][calculate.rulelist]. 
The 'n+1'th metric is +#' used when there is a match at 'nth' level, similar to [base::order]. If +#' there is a match at final level, row order of the rulelist comes into play. +#' @param minimize (logical vector) Whether to minimize. Either TRUE/FALSE or a +#' logical vector of same length as metric +#' @param init (positive integer) Initial number of rows after which reordering +#' should begin +#' @param ... passed to [calculate][calculate.rulelist] +#' +#' @examples +#' library("magrittr") +#' att = modeldata::attrition +#' tidy_c5 = +#' C50::C5.0(Attrition ~., data = att, rules = TRUE) %>% +#' tidy() %>% +#' set_validation_data(att, "Attrition") %>% +#' set_keys(NULL) %>% +#' head(5) +#' +#' # with defaults +#' reorder(tidy_c5) +#' +#' # use 'cumulative_overlap' to break ties (if any) +#' reorder(tidy_c5, metric = c("cumulative_coverage", "cumulative_overlap")) +#' +#' # reorder after 2 rules +#' reorder(tidy_c5, init = 2) +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] +#' @export +#' +reorder.rulelist = function(x, + metric = "cumulative_coverage", + minimize = FALSE, + init = NULL, + ... 
+ ){ + # checks + validate_rulelist(x) + + keys = attr(x, "keys") + validation_data = attr(x, "validation_data") + estimation_type = attr(x, "estimation_type") + model_type = attr(x, "model_type") + y_name = attr(x, "y_name") + weight = attr(x, "weight") + + # ignore keys + if (!is.null(keys)) { + cli::cli_alert_warning("'keys' will be ignored in `reorder`") + if (inherits(try(set_keys(x, NULL), silent = TRUE), "try-error")) { + x = set_keys(x, NULL, reset = TRUE) + } else { + x = set_keys(x, NULL) + } + keys = attr(x, "keys") + } + + checkmate::assert(!is.null(validation_data)) + checkmate::assert_character(metric) + # checks:done + + # utility function to set to rulelist + to_rulelist = function(df){ + df %>% + as_rulelist(keys = keys, + model_type = model_type, + estimation_type = estimation_type + ) %>% + set_validation_data(validation_data, y_name, weight) + } + + # handle init + checkmate::assert_integerish(init, + len = 1, + lower = 1, + upper = max(1, nrow(x) - 2), + null.ok = TRUE + ) + if (!is.null(init)) { + reordered_df = as.data.frame(x[1:init, ]) + x = x[(init + 1):(nrow(x)), ] # x changes at this point + } else { + reordered_df = NULL + } + + # core process + splitted = split(as.data.frame(x), 1:nrow(x)) + reordered_metrics = vector("list", length = nrow(x)) + + # wrapper where metric gets computed + wrapper_metric_fun = function(single_rule_df){ + + single_row_metric_df = + reordered_df %>% + bind_rows(single_rule_df) %>% + to_rulelist() %>% + calculate(...) %>% + select(all_of(metric)) %>% + tail(1) + + return(single_row_metric_df) + } + + # get init metrics + if (!is.null(init)) { + init_metrics = calculate(to_rulelist(reordered_df), ...) 
%>% + select(all_of(metric)) + } else { + init_metrics = NULL + } + + # loop through surviving rules + cli::cli_progress_bar("Reordering ...", clear = FALSE) + for (i in 1:nrow(x)) { + rule_metrics = purrr::map_dfr(splitted, wrapper_metric_fun) + ord = do.call(base::order, + c(rule_metrics, + list(decreasing = minimize) + ) + ) + pos = which(ord == 1) + reordered_metrics[[i]] = rule_metrics[pos, ] + reordered_df = bind_rows(reordered_df, splitted[[pos]]) + splitted[[pos]] = NULL + + cli::cli_progress_update() + } + cli::cli_progress_done() + + # put metrics df in same shape as x + reordered_metrics = + bind_rows(reordered_metrics) %>% + bind_rows(init_metrics, .) + + # put metrics to the right of reordered x + reordered_df = bind_cols(reordered_df, reordered_metrics) + + # return + return(to_rulelist(reordered_df)) +} diff --git a/R/tidy.R b/R/tidy.R index 2fa526e..b87a387 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -8,13 +8,17 @@ #' [tidyrules][package_tidyrules] package #' @description `tidy` applied on a supported model fit creates a [rulelist]. #' **See Also** section links to documentation of specific methods. +#' #' @param x A supported model object #' @param ... 
For model specific implementations to use -#' @seealso [tidy], [tidy.C5.0], [tidy.rpart], [tidy.constparty], [tidy.cubist], -#' [rulelist], [augment][augment.rulelist], [predict][predict.rulelist] +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @importFrom generics tidy #' @family Core Tidy Utility #' @export +#' generics::tidy #' @name tidy.C5.0 @@ -38,11 +42,12 @@ generics::tidy #' model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) #' tidy(model_c5) #' -#' @seealso [tidy], [tidy.C5.0], [tidy.rpart], [tidy.constparty], [tidy.cubist], -#' [rulelist], [augment.rulelist], [predict.rulelist] +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @family Core Tidy Utility #' @export - +#' tidy.C5.0 = function(x, ...){ #### checks ################################################################# @@ -371,11 +376,13 @@ tidy.C5.0 = function(x, ...){ #' #' model_regr_rpart = rpart::rpart(Sepal.Length ~ ., data = iris) #' tidy(model_regr_rpart) -#' @seealso [tidy], [tidy.C5.0], [tidy.rpart], [tidy.constparty], [tidy.cubist], -#' [rulelist], [augment.rulelist], [predict.rulelist] +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @family Core Tidy Utility #' @export - +#' tidy.rpart = function(x, ...){ ##### assertions and prep #################################################### @@ -514,11 +521,12 @@ tidy.rpart = function(x, ...){ #' model_regr_party = partykit::ctree(bill_length_mm ~ ., data = pen) #' tidy(model_regr_party) -#' @seealso [tidy], [tidy.C5.0], [tidy.rpart], [tidy.constparty], [tidy.cubist], -#' [rulelist], 
[augment.rulelist], [predict.rulelist] -#' @family Core Tidy Utility +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @export - +#' tidy.constparty = function(x, ...){ ##### assertions and prep #################################################### @@ -687,11 +695,13 @@ tidy.constparty = function(x, ...){ #' y = att[["MonthlyIncome"]] #' ) #' tidy(model_cubist) -#' @seealso [tidy], [tidy.C5.0], [tidy.rpart], [tidy.constparty], [tidy.cubist], -#' [rulelist], [augment.rulelist], [predict.rulelist] +#' +#' @seealso [rulelist], [tidy], [augment][augment.rulelist], +#' [predict][predict.rulelist], [calculate][calculate.rulelist], +#' [prune][prune.rulelist], [reorder][reorder.rulelist] #' @family Core Tidy Utility #' @export - +#' tidy.cubist = function(x, ...){ #### core rule extraction #################################################### diff --git a/R/utils.R b/R/utils.R index 693faa5..d2787c3 100644 --- a/R/utils.R +++ b/R/utils.R @@ -332,7 +332,7 @@ convert_rule_flavor = function(rule, flavor){ #' @name to_sql_case #' @title Extract SQL case statement from a [rulelist] #' @description Extract SQL case statement from a [rulelist] -#' @param x A [rulelist] object +#' @param rulelist A [rulelist] object #' @param rhs_column_name (string, default: "RHS") Name of the column in the #' rulelist to be used as RHS (WHEN THEN {rhs}) in the sql case #' statement @@ -348,22 +348,22 @@ convert_rule_flavor = function(rule, flavor){ #' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist], [convert_rule_flavor] #' @family Auxiliary Rulelist Utility #' @export -to_sql_case = function(x, +to_sql_case = function(rulelist, rhs_column_name = "RHS", output_colname = "output" ){ - checkmate::assert_class(x, "rulelist") - rhs_is_string = inherits(x[[rhs_column_name]], c("character", "factor")) - lhs_sql = 
convert_rule_flavor(x$LHS, flavor = "sql") + checkmate::assert_class(rulelist, "rulelist") + rhs_is_string = inherits(rulelist[[rhs_column_name]], c("character", "factor")) + lhs_sql = convert_rule_flavor(rulelist$LHS, flavor = "sql") out = "CASE" - for (rn in seq_len(nrow(x))) { + for (rn in seq_len(nrow(rulelist))) { if (rhs_is_string) { - lhs = glue::glue("WHEN {lhs_sql[rn]} THEN '{x[[rhs_column_name]][rn]}'") + lhs = glue::glue("WHEN {lhs_sql[rn]} THEN '{rulelist[[rhs_column_name]][rn]}'") } else { - lhs = glue::glue("WHEN {lhs_sql[rn]} THEN {x[[rhs_column_name]][rn]}") + lhs = glue::glue("WHEN {lhs_sql[rn]} THEN {rulelist[[rhs_column_name]][rn]}") } out = paste(out, lhs, sep = "\n") } diff --git a/man/as_rulelist.Rd b/man/as_rulelist.Rd index a0ab4e7..8a1d16c 100644 --- a/man/as_rulelist.Rd +++ b/man/as_rulelist.Rd @@ -18,12 +18,7 @@ A \link{rulelist} as_rulelist generic } \seealso{ -\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} - -Other Core Rulelist Utility: -\code{\link{as_rulelist.data.frame}()}, -\code{\link{predict.rulelist}()}, -\code{\link{print.rulelist}()}, -\code{\link{set_keys}()} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} } -\concept{Core Rulelist Utility} diff --git a/man/as_rulelist.data.frame.Rd b/man/as_rulelist.data.frame.Rd index 8985b72..7459d1e 100644 --- a/man/as_rulelist.data.frame.Rd +++ b/man/as_rulelist.data.frame.Rd @@ -37,12 +37,7 @@ rules_df = tidytable::tidytable(rule_nbr = 1:2, as_rulelist(rules_df, estimation_type = "regression") } \seealso{ -\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} - -Other Core Rulelist Utility: -\code{\link{as_rulelist}()}, -\code{\link{predict.rulelist}()}, -\code{\link{print.rulelist}()}, -\code{\link{set_keys}()} +\link{rulelist}, 
\link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} } -\concept{Core Rulelist Utility} diff --git a/man/augment.Rd b/man/augment.Rd index c79eb74..77f5fe3 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/augment.R +% Please edit documentation in R/rulelist.R \name{augment} \alias{augment} \title{\code{augment} is re-export of \link[generics:augment]{generics::augment} from @@ -16,9 +16,7 @@ augment(x, ...) See \link{augment.rulelist} } \seealso{ -\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} - -Other Augment: -\code{\link{augment.rulelist}()} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} } -\concept{Augment} diff --git a/man/augment.rulelist.Rd b/man/augment.rulelist.Rd index bf9ad13..9b91f21 100644 --- a/man/augment.rulelist.Rd +++ b/man/augment.rulelist.Rd @@ -1,21 +1,14 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/augment.R +% Please edit documentation in R/rulelist.R \name{augment.rulelist} \alias{augment.rulelist} \title{Augment a \link{rulelist}} \usage{ -\method{augment}{rulelist}(x, new_data, y_name, weight = 1L, ...) +\method{augment}{rulelist}(x, ...) } \arguments{ \item{x}{A \link{rulelist}} -\item{new_data}{(dataframe) with column named \code{y_name} present} - -\item{y_name}{(string) Column name representing the dependent variable} - -\item{weight}{(numeric, default: 1) Positive weight vector with length equal -to one or number of rows of 'new_data'} - \item{...}{(expressions) To be send to \link[tidytable:summarize]{tidytable::summarise} for custom aggregations. 
See examples.} } @@ -24,7 +17,8 @@ A \link{rulelist} with a new dataframe-column named \code{augmented_stats}. } \description{ \code{augment} outputs a \link{rulelist} with an additional column named -\code{augmented_stats} based on summary statistics calculated using \code{new_data}. +\code{augmented_stats} based on summary statistics calculated using attribute +\code{validation_data}. } \details{ The dataframe-column \code{augmented_stats} will have these columns @@ -34,8 +28,7 @@ corresponding to the \code{estimation_type}: \item For \code{classification}: \code{support}, \code{confidence}, \code{lift} } -All these metrics are computed in a weighted sense. Arg \code{weight} is 1 by -default. +along with custom aggregations. } \examples{ # Examples for augment ------------------------------------------------------ @@ -47,25 +40,19 @@ set.seed(100) train_index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE) model_c5 = C50::C5.0(Attrition ~., data = att[train_index, ], rules = TRUE) -tidy_c5 = tidy(model_c5) -tidy_c5 +tidy_c5 = + model_c5 \%>\% + tidy() \%>\% + set_validation_data(att[!train_index, ], "Attrition") -# augment -augmented = augment(tidy_c5, new_data = att[!train_index, ], y_name = "Attrition") +tidy_c5 -augmented \%>\% +augment(tidy_c5) \%>\% tidytable::unnest(augmented_stats, names_sep = "__") \%>\% tidytable::glimpse() # augment with custom aggregator -augmented = - augment(tidy_c5, - new_data = att[!train_index, ], - y_name = "Attrition", - output_counts = list(table(Attrition)) - ) - -augmented \%>\% +augment(tidy_c5,output_counts = list(table(Attrition))) \%>\% tidytable::unnest(augmented_stats, names_sep = "__") \%>\% tidytable::glimpse() @@ -74,71 +61,49 @@ set.seed(100) train_index = sample(c(TRUE, FALSE), nrow(iris), replace = TRUE) model_class_rpart = rpart::rpart(Species ~ ., data = iris[train_index, ]) -tidy_class_rpart = tidy(model_class_rpart) +tidy_class_rpart = tidy(model_class_rpart) \%>\% + set_validation_data(iris[!train_index, 
], "Species") tidy_class_rpart model_regr_rpart = rpart::rpart(Sepal.Length ~ ., data = iris[train_index, ]) -tidy_regr_rpart = tidy(model_regr_rpart) +tidy_regr_rpart = tidy(model_regr_rpart) \%>\% + set_validation_data(iris[!train_index, ], "Sepal.Length") tidy_regr_rpart -#' augment (classification case) -augmented = - augment(tidy_class_rpart, - new_data = iris[!train_index, ], - y_name = "Species" - ) -augmented - -augmented \%>\% +# augment (classification case) +augment(tidy_class_rpart) \%>\% tidytable::unnest(augmented_stats, names_sep = "__") \%>\% tidytable::glimpse() -#' augment (regression case) -augmented = - augment(tidy_regr_rpart, - new_data = iris[!train_index, ], - y_name = "Sepal.Length" - ) -augmented - -augmented \%>\% +# augment (regression case) +augment(tidy_regr_rpart) \%>\% tidytable::unnest(augmented_stats, names_sep = "__") \%>\% tidytable::glimpse() # party ---- -pen = palmerpenguins::penguins +pen = palmerpenguins::penguins \%>\% + tidytable::drop_na(bill_length_mm) set.seed(100) train_index = sample(c(TRUE, FALSE), nrow(pen), replace = TRUE) model_class_party = partykit::ctree(species ~ ., data = pen[train_index, ]) -tidy_class_party = tidy(model_class_party) +tidy_class_party = tidy(model_class_party) \%>\% + set_validation_data(pen[!train_index, ], "species") tidy_class_party -model_regr_party = partykit::ctree(bill_length_mm ~ ., data = pen[train_index, ]) -tidy_regr_party = tidy(model_regr_party) +model_regr_party = + partykit::ctree(bill_length_mm ~ ., data = pen[train_index, ]) +tidy_regr_party = tidy(model_regr_party) \%>\% + set_validation_data(pen[!train_index, ], "bill_length_mm") tidy_regr_party -#' augment (classification case) -augmented = - augment(tidy_class_party, - new_data = pen[!train_index, ], - y_name = "species" - ) -augmented - -augmented \%>\% +# augment (classification case) +augment(tidy_class_party) \%>\% tidytable::unnest(augmented_stats, names_sep = "__") \%>\% tidytable::glimpse() -#' augment (regression 
case) -augmented = - augment(tidy_regr_party, - new_data = tidytable::drop_na(pen[!train_index, ], bill_length_mm), - y_name = "bill_length_mm" - ) -augmented - -augmented \%>\% +# augment (regression case) +augment(tidy_regr_party) \%>\% tidytable::unnest(augmented_stats, names_sep = "__") \%>\% tidytable::glimpse() @@ -152,25 +117,17 @@ model_cubist = Cubist::cubist(x = att[train_index, cols_att], y = att[train_index, "MonthlyIncome"] ) -tidy_cubist = tidy(model_cubist) +tidy_cubist = tidy(model_cubist) \%>\% + set_validation_data(att[!train_index, ], "MonthlyIncome") tidy_cubist -augmented = - augment(tidy_cubist, - new_data = att[!train_index, ], - y_name = "MonthlyIncome" - ) -augmented - -augmented \%>\% +augment(tidy_cubist) \%>\% tidytable::unnest(augmented_stats, names_sep = "__") \%>\% tidytable::glimpse() } \seealso{ -\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} - -Other Augment: -\code{\link{augment}()} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} } -\concept{Augment} diff --git a/man/augment_class_keys.Rd b/man/augment_class_keys.Rd index 715c90d..e8cd02a 100644 --- a/man/augment_class_keys.Rd +++ b/man/augment_class_keys.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/augment.R +% Please edit documentation in R/rulelist.R \name{augment_class_keys} \alias{augment_class_keys} \title{as the name says} \usage{ -augment_class_keys(x, new_data, y_name, weight = 1L, ...) +augment_class_keys(x, new_data, y_name, weight, ...) 
} \description{ as the name says diff --git a/man/augment_class_no_keys.Rd b/man/augment_class_no_keys.Rd index 628504d..8a4b400 100644 --- a/man/augment_class_no_keys.Rd +++ b/man/augment_class_no_keys.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/augment.R +% Please edit documentation in R/rulelist.R \name{augment_class_no_keys} \alias{augment_class_no_keys} \title{as the name says} \usage{ -augment_class_no_keys(x, new_data, y_name, weight = 1L, ...) +augment_class_no_keys(x, new_data, y_name, weight, ...) } \description{ as the name says diff --git a/man/augment_regr_keys.Rd b/man/augment_regr_keys.Rd index 60d1463..66978a1 100644 --- a/man/augment_regr_keys.Rd +++ b/man/augment_regr_keys.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/augment.R +% Please edit documentation in R/rulelist.R \name{augment_regr_keys} \alias{augment_regr_keys} \title{as the name says} \usage{ -augment_regr_keys(x, new_data, y_name, weight = 1L, ...) +augment_regr_keys(x, new_data, y_name, weight, ...) } \description{ as the name says diff --git a/man/augment_regr_no_keys.Rd b/man/augment_regr_no_keys.Rd index b060b38..8703d90 100644 --- a/man/augment_regr_no_keys.Rd +++ b/man/augment_regr_no_keys.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/augment.R +% Please edit documentation in R/rulelist.R \name{augment_regr_no_keys} \alias{augment_regr_no_keys} \title{as the name says} \usage{ -augment_regr_no_keys(x, new_data, y_name, weight = 1L, ...) +augment_regr_no_keys(x, new_data, y_name, weight, ...) 
} \description{ as the name says diff --git a/man/calculate.Rd b/man/calculate.Rd new file mode 100644 index 0000000..b7aa50b --- /dev/null +++ b/man/calculate.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{calculate} +\alias{calculate} +\title{\code{calculate} is re-export of \link[generics:calculate]{generics::calculate} from +\link[=package_tidyrules]{tidyrules} package} +\usage{ +calculate(x, ...) +} +\arguments{ +\item{x}{A \link{rulelist}} + +\item{...}{See \link{calculate.rulelist}} +} +\description{ +See \link{calculate.rulelist} +} +\seealso{ +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} +} diff --git a/man/calculate.rulelist.Rd b/man/calculate.rulelist.Rd new file mode 100644 index 0000000..e0fb91f --- /dev/null +++ b/man/calculate.rulelist.Rd @@ -0,0 +1,141 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{calculate.rulelist} +\alias{calculate.rulelist} +\title{\code{calculate} metrics for a \link{rulelist}} +\usage{ +\method{calculate}{rulelist}(x, metrics_to_exclude = NULL, ...) +} +\arguments{ +\item{x}{A \link{rulelist}} + +\item{metrics_to_exclude}{(character vector) Names of metrics to exclude} + +\item{...}{Named list of custom metrics. See 'details'.} +} +\value{ +A dataframe of metrics with a \code{rule_nbr} column. +} +\description{ +Computes some metrics (based on \code{estimation_type}) in cumulative +window function style over the rulelist (in the same order) ignoring the +keys. +} +\details{ +\subsection{Default Metrics}{ + +These metrics are calculated by default: +\itemize{ +\item \code{cumulative_coverage}: For nth rule in the rulelist, number of distinct \code{row_nbr}s (of \code{new_data}) covered by nth and all preceding rules (in order). 
In weighted case, we sum the weights corresponding to the distinct \code{row_nbr}s. +\item \code{cumulative_overlap}: Up til nth rule in the rulelist, number of distinct \code{row_nbr}s (of \code{new_data}) already covered by some preceding rule (in order). In weighted case, we sum the weights corresponding to the distinct \code{row_nbr}s. +} + +For classification: +\itemize{ +\item \code{cumulative_accuracy}: For nth rule in the rulelist, fraction of \code{row_nbr}s such that \code{RHS} matches the \code{y_name} column (of \code{new_data}) by nth and all preceding rules (in order). In weighted case, weighted accuracy is computed. +} + +For regression: +\itemize{ +\item \code{cumulative_RMSE}: For nth rule in the rulelist, weighted RMSE of all predictions (\code{RHS}) predicted by nth rule and all preceding rules. +} +} + +\subsection{Custom metrics}{ + +Custom metrics to be computed should be passed a named list of function(s) in +\code{...}. The custom metric function should take these arguments in same order: +\code{rulelist}, \code{new_data}, \code{y_name}, \code{weight}. The custom metric function should +return a numeric vector of same length as the number of rows of rulelist. 
+} +} +\examples{ +library("magrittr") +model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) +tidy_c5 = tidy(model_c5) \%>\% + set_validation_data(modeldata::attrition, "Attrition") \%>\% + set_keys(NULL) + +# calculate default metrics (classification) +calculate(tidy_c5) + +model_rpart = rpart::rpart(MonthlyIncome ~., data = modeldata::attrition) +tidy_rpart = + tidy(model_rpart) \%>\% + set_validation_data(modeldata::attrition, "MonthlyIncome") \%>\% + set_keys(NULL) + +# calculate default metrics (regression) +calculate(tidy_rpart) + +# calculate default metrics with a custom metric +#' custom function to get cumulative MAE +library("tidytable") +get_cumulative_MAE = function(rulelist, new_data, y_name, weight){ + + priority_df = + rulelist \%>\% + select(rule_nbr) \%>\% + mutate(priority = 1:nrow(rulelist)) \%>\% + select(rule_nbr, priority) + + pred_df = + predict(rulelist, new_data) \%>\% + left_join(priority_df, by = "rule_nbr") \%>\% + mutate(weight = local(weight)) \%>\% + select(rule_nbr, row_nbr, weight, priority) + + new_data2 = + new_data \%>\% + mutate(row_nbr = 1:n()) \%>\% + select(all_of(c("row_nbr", y_name))) + + rmse_till_rule = function(rn){ + + if (is.character(rulelist$RHS)) { + inter_df = + pred_df \%>\% + tidytable::filter(priority <= rn) \%>\% + left_join(mutate(new_data, row_nbr = 1:n()), by = "row_nbr") \%>\% + left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") \%>\% + nest(.by = c("RHS", "rule_nbr", "row_nbr", "priority", "weight")) \%>\% + mutate(RHS = purrr::map2_dbl(RHS, + data, + ~ eval(parse(text = .x), envir = .y) + ) + ) \%>\% + unnest(data) + } else { + + inter_df = + pred_df \%>\% + tidytable::filter(priority <= rn) \%>\% + left_join(new_data2, by = "row_nbr") \%>\% + left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") + } + + inter_df \%>\% + summarise(rmse = MetricsWeighted::mae(RHS, + .data[[y_name]], + weight, + na.rm = TRUE + ) + ) \%>\% + `[[`("rmse") + } + + res = 
purrr::map_dbl(1:nrow(rulelist), rmse_till_rule) + return(res) +} + +calculate(tidy_rpart, + metrics_to_exclude = NULL, + list("cumulative_mae" = get_cumulative_MAE) + ) + +} +\seealso{ +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} +} diff --git a/man/plot.prune_rulelist.Rd b/man/plot.prune_rulelist.Rd new file mode 100644 index 0000000..b662061 --- /dev/null +++ b/man/plot.prune_rulelist.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{plot.prune_rulelist} +\alias{plot.prune_rulelist} +\title{Plot method for \code{prune_rulelist} class} +\usage{ +\method{plot}{prune_rulelist}(x, ...) +} +\arguments{ +\item{x}{A 'prune_rulelist' object} + +\item{...}{unused} +} +\value{ +ggplot2 object (invisibly) +} +\description{ +Plot method for \code{prune_rulelist} class +} diff --git a/man/plot.rulelist.Rd b/man/plot.rulelist.Rd new file mode 100644 index 0000000..bcd9a6e --- /dev/null +++ b/man/plot.rulelist.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{plot.rulelist} +\alias{plot.rulelist} +\title{Plot method for rulelist} +\usage{ +\method{plot}{rulelist}(x, thres_cluster_rows = 1000, dist_metric = "jaccard", ...) 
+} +\arguments{ +\item{x}{A \link{rulelist}} + +\item{thres_cluster_rows}{(positive integer) Maximum number of rows beyond +which a x-side dendrogram is not computed} + +\item{dist_metric}{(string or function, default: "jaccard") Distance metric +for y-side (\code{rule_nbr}) passed to \code{method} argument of \link[proxy:dist]{proxy::dist}} + +\item{...}{Arguments to be passed to \link[pheatmap:pheatmap]{pheatmap::pheatmap}} +} +\description{ +Plots a heatmap with \code{rule_nbr}'s on x-side and clusters of +\code{row_nbr}'s on y-side of a binary matrix with 1 if a rule is applicable for +a row. +} +\details{ +Number of clusters is set to min(number of unique rows in the +row_nbr X rule_nbr matrix and thres_cluster_rows) +} +\examples{ +library("magrittr") +att = modeldata::attrition +tidy_c5 = + C50::C5.0(Attrition ~., data = att, rules = TRUE) \%>\% + tidy() \%>\% + set_validation_data(att, "Attrition") \%>\% + set_keys(NULL) + +plot(tidy_c5) + +} diff --git a/man/predict.rulelist.Rd b/man/predict.rulelist.Rd index 5175313..08a8eb1 100644 --- a/man/predict.rulelist.Rd +++ b/man/predict.rulelist.Rd @@ -17,7 +17,7 @@ applicable for a row. If FALSE, the first satisfying rule is provided.} \item{...}{unused} } \value{ -dataframe. See \strong{Details}. +A dataframe. See \strong{Details}. 
} \description{ Predicts \code{rule_nbr} applicable (as per the order in rulelist) @@ -53,14 +53,15 @@ output_1 # different rules per 'keys' (`trial_nbr` here) output_2 = predict(tidy_c5, palmerpenguins::penguins, multiple = TRUE) output_2 # `rule_nbr` is a list-column of integer vectors + } \seealso{ -\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} Other Core Rulelist Utility: -\code{\link{as_rulelist}()}, -\code{\link{as_rulelist.data.frame}()}, -\code{\link{print.rulelist}()}, -\code{\link{set_keys}()} +\code{\link{set_keys}()}, +\code{\link{set_validation_data}()} } \concept{Core Rulelist Utility} diff --git a/man/print.prune_rulelist.Rd b/man/print.prune_rulelist.Rd new file mode 100644 index 0000000..047f276 --- /dev/null +++ b/man/print.prune_rulelist.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{print.prune_rulelist} +\alias{print.prune_rulelist} +\title{Print method for \code{prune_rulelist} class} +\usage{ +\method{print}{prune_rulelist}(x, ...) +} +\arguments{ +\item{x}{A 'prune_rulelist' object} + +\item{...}{unused} +} +\description{ +Print method for \code{prune_rulelist} class +} diff --git a/man/print.rulelist.Rd b/man/print.rulelist.Rd index 2718106..b27443d 100644 --- a/man/print.rulelist.Rd +++ b/man/print.rulelist.Rd @@ -18,12 +18,7 @@ input \link{rulelist} (invisibly) Prints \link{rulelist} attributes and first few rows. 
} \seealso{ -\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} - -Other Core Rulelist Utility: -\code{\link{as_rulelist}()}, -\code{\link{as_rulelist.data.frame}()}, -\code{\link{predict.rulelist}()}, -\code{\link{set_keys}()} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} } -\concept{Core Rulelist Utility} diff --git a/man/prune.Rd b/man/prune.Rd new file mode 100644 index 0000000..b28e4bf --- /dev/null +++ b/man/prune.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{prune} +\alias{prune} +\title{\code{prune} is re-export of \link[generics:prune]{generics::prune} from +\link[=package_tidyrules]{tidyrules} package} +\usage{ +prune(tree, ...) +} +\arguments{ +\item{tree}{A \link{rulelist}} + +\item{...}{See \link{prune.rulelist}} +} +\description{ +See \link{prune.rulelist} +} +\seealso{ +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} +} diff --git a/man/prune.rulelist.Rd b/man/prune.rulelist.Rd new file mode 100644 index 0000000..af189d7 --- /dev/null +++ b/man/prune.rulelist.Rd @@ -0,0 +1,84 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{prune.rulelist} +\alias{prune.rulelist} +\title{\code{prune} rules of a \link{rulelist}} +\usage{ +\method{prune}{rulelist}( + tree, + metrics_to_exclude = NULL, + stop_expr_string = "relative__cumulative_coverage >= 0.9", + min_n_rules = 1, + ... +) +} +\arguments{ +\item{tree}{A \link{rulelist}} + +\item{metrics_to_exclude}{(character vector or NULL) Names of metrics not to +be calculated. 
See \link[=calculate.rulelist]{calculate} for the list of default +metrics.} + +\item{stop_expr_string}{(string default: "relative__cumulative_coverage >= +0.9") Parsable condition} + +\item{min_n_rules}{(positive integer) Minimum number of rules to keep} + +\item{...}{Named list of custom metrics passed to +\link[=calculate.rulelist]{calculate}} +} +\value{ +Object of class 'prune_rulelist' with these components: 1. pruned: +ruleset keeping only first 'pos' rows. 2. n_pruned_rules: pos. If stop +criteria is never met, then pos = nrow(ruleset) 3. n_total_rules: +nrow(ruleset), 4. metrics_df: Dataframe with metrics and relative metrics +5. stop_expr_string +} +\description{ +Prune the \link{rulelist} by suggesting to keep first 'k' rules based +on metrics computed by \link[=calculate.rulelist]{calculate} +} +\details{ +\enumerate{ +\item Metrics are computed using \link[=calculate.rulelist]{calculate}. 2. +Relative metrics (prepended by 'relative__') are calculated by dividing +each metric by its max value. 3. The first rule in rulelist order which +meets the 'stop_expr_string' criteria is stored (say 'pos'). Print method +suggests to keep rules until pos. 
+} +} +\examples{ +library("magrittr") +model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) +tidy_c5 = tidy(model_c5) \%>\% + set_validation_data(modeldata::attrition, "Attrition") \%>\% + set_keys(NULL) + +#' prune with defaults +prune_obj = prune(tidy_c5) +#' note that all other metrics are visible in the print output +prune_obj +plot(prune_obj) +prune_obj$pruned + +#' prune with a different stop_expr_string threshold +prune_obj = prune(tidy_c5, + stop_expr_string = "relative__cumulative_coverage >= 0.2" + ) +prune_obj #' as expected, has fewer than 10 rules as compared to default args +plot(prune_obj) +prune_obj$pruned + +#' prune with a different stop_expr_string metric +st = "relative__cumulative_overlap <= 0.7 & relative__cumulative_overlap > 0" +prune_obj = prune(tidy_c5, stop_expr_string = st) +prune_obj #' as expected, has fewer than 10 rules as compared to default args +plot(prune_obj) +prune_obj$pruned + +} +\seealso{ +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} +} diff --git a/man/reorder.Rd b/man/reorder.Rd new file mode 100644 index 0000000..89d6f6c --- /dev/null +++ b/man/reorder.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{reorder} +\alias{reorder} +\title{reorder generic} +\usage{ +reorder(x, ...) 
+} +\arguments{ +\item{x}{A \link{rulelist}} + +\item{...}{See \link{reorder.rulelist}} +} +\description{ +reorder generic for rulelist +} +\seealso{ +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} +} diff --git a/man/reorder.rulelist.Rd b/man/reorder.rulelist.Rd new file mode 100644 index 0000000..dbb13e8 --- /dev/null +++ b/man/reorder.rulelist.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{reorder.rulelist} +\alias{reorder.rulelist} +\title{Reorder the rules/rows of a \link{rulelist}} +\usage{ +\method{reorder}{rulelist}(x, metric = "cumulative_coverage", minimize = FALSE, init = NULL, ...) +} +\arguments{ +\item{x}{A \link{rulelist}} + +\item{metric}{(character vector or named list) Name of metrics or a custom +function(s). See \link[=calculate.rulelist]{calculate}. The 'n+1'th metric is +used when there is a match at 'nth' level, similar to \link[base:order]{base::order}. If +there is a match at final level, row order of the rulelist comes into play.} + +\item{minimize}{(logical vector) Whether to minimize. Either TRUE/FALSE or a +logical vector of same length as metric} + +\item{init}{(positive integer) Initial number of rows after which reordering +should begin} + +\item{...}{passed to \link[=calculate.rulelist]{calculate}} +} +\description{ +Implements a greedy strategy to add one rule at a time which +maximizes/minimizes a metric. 
+} +\examples{ +library("magrittr") +att = modeldata::attrition +tidy_c5 = + C50::C5.0(Attrition ~., data = att, rules = TRUE) \%>\% + tidy() \%>\% + set_validation_data(att, "Attrition") \%>\% + set_keys(NULL) \%>\% + head(5) + +# with defaults +reorder(tidy_c5) + +# use 'cumulative_overlap' to break ties (if any) +reorder(tidy_c5, metric = c("cumulative_coverage", "cumulative_overlap")) + +# reorder after 2 rules +reorder(tidy_c5, init = 2) + +} +\seealso{ +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} +} diff --git a/man/rulelist.Rd b/man/rulelist.Rd index f3516c6..4221de6 100644 --- a/man/rulelist.Rd +++ b/man/rulelist.Rd @@ -52,6 +52,12 @@ A rulelist has these optional attributes: \item \code{keys}: (character vector)Names of the column that forms a key. \item \code{model_type}: (string) Name of the model } + +\subsection{Set Validation data}{ + +A few methods like \link{augment}, \link{calculate}, \link{prune}, \link{reorder} +require a few additional attributes which can be set using +\link{set_validation_data}. } \subsection{Methods for rulelist}{ @@ -60,9 +66,16 @@ A rulelist has these optional attributes: dependent variable column aka 'test data'), predicts the first rule (as ordered in the rulelist) per 'keys' that is applicable for each row. When \code{multiple = TRUE}, returns all rules applicable for a row (per key). -\item \link[=augment.rulelist]{Augment}: Given a dataframe (with dependent variable -column, aka validation data), creates summary statistics per rule and -returns a rulelist with a new dataframe-column. +\item \link[=augment.rulelist]{Augment}: Outputs summary statistics per rule over +validation data and returns a rulelist with a new dataframe-column. 
+\item \link[=calculate.rulelist]{Calculate}: Computes metrics for a rulelist in a +cumulative manner such as \code{cumulative_coverage}, \code{cumulative_overlap}, +\code{cumulative_accuracy}. +\item \link[=prune.rulelist]{Prune}: Suggests pruning a rulelist such that some +expectations are met (based on metrics). Example: cumulative_coverage of 80\% +can be met with the first few rules. +\item \link[=reorder.rulelist]{Reorder}: Reorders a rulelist in order to maximize a +metric. } } @@ -77,14 +90,17 @@ dataframe worlds. \subsection{Utilities for a rulelist}{ \enumerate{ \item \link[=as_rulelist.data.frame]{as_rulelist}: Create a \code{rulelist} from a -dataframe with some mandatory columns. 2. \link{set_keys}: Set or Unset 'keys' -of a \code{rulelist}. 3. \link{to_sql_case}: Outputs a SQL case statement for a -\code{rulelist}. 4. \link{convert_rule_flavor}: Converts \code{R}-parsable rule strings to -python/SQL parsable rule strings. +dataframe with some mandatory columns. +\item \link{set_keys}: Set or Unset 'keys' of a \code{rulelist}. +\item \link{to_sql_case}: Outputs a SQL case statement for a \code{rulelist}. +\item \link{convert_rule_flavor}: Converts \code{R}-parsable rule strings to python/SQL +parsable rule strings. 
+} } } } \seealso{ \link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, -\link[=predict.rulelist]{predict} +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} } diff --git a/man/set_keys.Rd b/man/set_keys.Rd index d6c535b..f8c4cf1 100644 --- a/man/set_keys.Rd +++ b/man/set_keys.Rd @@ -4,20 +4,23 @@ \alias{set_keys} \title{Set keys for a \link{rulelist}} \usage{ -set_keys(x, keys) +set_keys(x, keys, reset = FALSE) } \arguments{ \item{x}{A \link{rulelist}} \item{keys}{(character vector or NULL)} + +\item{reset}{(flag) Whether to reset the keys to sequential numbers starting +with 1 when \code{keys} is set to NULL} } \value{ A \link{rulelist} object } \description{ -'keys' are a set of column(s) whose unique combination -identifies a group of rules in a \link{rulelist}. Methods like -\link{predict.rulelist}, \link{augment.rulelist} produce output per key combination. +'keys' are a set of column(s) which identify a group of rules in +a \link{rulelist}. Methods like \link[=predict.rulelist]{predict}, +\link[=augment.rulelist]{augment} produce output per key combination. } \details{ A new \link{rulelist} is returned with attr \code{keys} is modified. The input @@ -28,22 +31,25 @@ model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) tidy_c5 = tidy(model_c5) tidy_c5 # keys are: "trial_nbr" +tidy_c5[["rule_nbr"]] = 1:nrow(tidy_c5) new_tidy_c5 = set_keys(tidy_c5, NULL) # remove all keys new_tidy_c5 + new_2_tidy_c5 = set_keys(new_tidy_c5, "trial_nbr") # set "trial_nbr" as key new_2_tidy_c5 # Note that `tidy_c5` and `new_tidy_c5` are not altered. 
tidy_c5 new_tidy_c5 + } \seealso{ -\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} Other Core Rulelist Utility: -\code{\link{as_rulelist}()}, -\code{\link{as_rulelist.data.frame}()}, \code{\link{predict.rulelist}()}, -\code{\link{print.rulelist}()} +\code{\link{set_validation_data}()} } \concept{Core Rulelist Utility} diff --git a/man/set_validation_data.Rd b/man/set_validation_data.Rd new file mode 100644 index 0000000..9192b34 --- /dev/null +++ b/man/set_validation_data.Rd @@ -0,0 +1,58 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/rulelist.R +\name{set_validation_data} +\alias{set_validation_data} +\title{Add \code{validation_data} to a \link{rulelist}} +\usage{ +set_validation_data(x, validation_data, y_name, weight = 1) +} +\arguments{ +\item{x}{A \link{rulelist}} + +\item{validation_data}{(dataframe) Data to be used for computing some metrics. +It is expected to contain \code{y_name} column.} + +\item{y_name}{(string) Name of the dependent variable column.} + +\item{weight}{(non-negative numeric vector, default: 1) Weight per +observation/row of \code{validation_data}. This is expected to have same length +as the number of rows in \code{validation_data}. Only exception is when it is a +single positive number, which means that all rows have equal weight.} +} +\value{ +A \link{rulelist} with some extra attributes set. +} +\description{ +Returns a \link{rulelist} with three new attributes set: +\code{validation_data}, \code{y_name} and \code{weight}. Methods such as +\link[=augment.rulelist]{augment}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link{reorder} require this to be set. 
+} +\examples{ +att = modeldata::attrition +set.seed(100) +index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE) +model_c5 = C50::C5.0(Attrition ~., data = att[index, ], rules = TRUE) + +tidy_c5 = tidy(model_c5) +tidy_c5 + +tidy_c5_2 = set_validation_data(tidy_c5, + validation_data = att[!index, ], + y_name = "Attrition", + weight = 1 # default + ) +tidy_c5_2 +tidy_c5 # not altered + +} +\seealso{ +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} + +Other Core Rulelist Utility: +\code{\link{predict.rulelist}()}, +\code{\link{set_keys}()} +} +\concept{Core Rulelist Utility} diff --git a/man/tidy.C5.0.Rd b/man/tidy.C5.0.Rd index 8a18370..51445a3 100644 --- a/man/tidy.C5.0.Rd +++ b/man/tidy.C5.0.Rd @@ -38,12 +38,12 @@ tidy(model_c5) } \seealso{ -\link{tidy}, \link{tidy.C5.0}, \link{tidy.rpart}, \link{tidy.constparty}, \link{tidy.cubist}, -\link{rulelist}, \link{augment.rulelist}, \link{predict.rulelist} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} Other Core Tidy Utility: \code{\link{tidy}()}, -\code{\link{tidy.constparty}()}, \code{\link{tidy.cubist}()}, \code{\link{tidy.rpart}()} } diff --git a/man/tidy.Rd b/man/tidy.Rd index 616bb1e..370d7d6 100644 --- a/man/tidy.Rd +++ b/man/tidy.Rd @@ -17,12 +17,12 @@ tidy(x, ...) \strong{See Also} section links to documentation of specific methods. 
} \seealso{ -\link{tidy}, \link{tidy.C5.0}, \link{tidy.rpart}, \link{tidy.constparty}, \link{tidy.cubist}, -\link{rulelist}, \link[=augment.rulelist]{augment}, \link[=predict.rulelist]{predict} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} Other Core Tidy Utility: \code{\link{tidy.C5.0}()}, -\code{\link{tidy.constparty}()}, \code{\link{tidy.cubist}()}, \code{\link{tidy.rpart}()} } diff --git a/man/tidy.constparty.Rd b/man/tidy.constparty.Rd index 3938991..6d335ad 100644 --- a/man/tidy.constparty.Rd +++ b/man/tidy.constparty.Rd @@ -40,15 +40,10 @@ model_class_party = partykit::ctree(species ~ ., data = pen) tidy(model_class_party) model_regr_party = partykit::ctree(bill_length_mm ~ ., data = pen) tidy(model_regr_party) + } \seealso{ -\link{tidy}, \link{tidy.C5.0}, \link{tidy.rpart}, \link{tidy.constparty}, \link{tidy.cubist}, -\link{rulelist}, \link{augment.rulelist}, \link{predict.rulelist} - -Other Core Tidy Utility: -\code{\link{tidy}()}, -\code{\link{tidy.C5.0}()}, -\code{\link{tidy.cubist}()}, -\code{\link{tidy.rpart}()} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} } -\concept{Core Tidy Utility} diff --git a/man/tidy.cubist.Rd b/man/tidy.cubist.Rd index 0f68970..56c824a 100644 --- a/man/tidy.cubist.Rd +++ b/man/tidy.cubist.Rd @@ -31,15 +31,16 @@ model_cubist = Cubist::cubist(x = att[, cols_att], y = att[["MonthlyIncome"]] ) tidy(model_cubist) + } \seealso{ -\link{tidy}, \link{tidy.C5.0}, \link{tidy.rpart}, \link{tidy.constparty}, \link{tidy.cubist}, -\link{rulelist}, \link{augment.rulelist}, \link{predict.rulelist} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, 
\link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} Other Core Tidy Utility: \code{\link{tidy}()}, \code{\link{tidy.C5.0}()}, -\code{\link{tidy.constparty}()}, \code{\link{tidy.rpart}()} } \concept{Core Tidy Utility} diff --git a/man/tidy.rpart.Rd b/man/tidy.rpart.Rd index 9036f3c..a581be0 100644 --- a/man/tidy.rpart.Rd +++ b/man/tidy.rpart.Rd @@ -39,15 +39,16 @@ tidy(model_class_rpart) model_regr_rpart = rpart::rpart(Sepal.Length ~ ., data = iris) tidy(model_regr_rpart) + } \seealso{ -\link{tidy}, \link{tidy.C5.0}, \link{tidy.rpart}, \link{tidy.constparty}, \link{tidy.cubist}, -\link{rulelist}, \link{augment.rulelist}, \link{predict.rulelist} +\link{rulelist}, \link{tidy}, \link[=augment.rulelist]{augment}, +\link[=predict.rulelist]{predict}, \link[=calculate.rulelist]{calculate}, +\link[=prune.rulelist]{prune}, \link[=reorder.rulelist]{reorder} Other Core Tidy Utility: \code{\link{tidy}()}, \code{\link{tidy.C5.0}()}, -\code{\link{tidy.constparty}()}, \code{\link{tidy.cubist}()} } \concept{Core Tidy Utility} diff --git a/man/to_sql_case.Rd b/man/to_sql_case.Rd index efd2ab4..58f756c 100644 --- a/man/to_sql_case.Rd +++ b/man/to_sql_case.Rd @@ -4,10 +4,10 @@ \alias{to_sql_case} \title{Extract SQL case statement from a \link{rulelist}} \usage{ -to_sql_case(x, rhs_column_name = "RHS", output_colname = "output") +to_sql_case(rulelist, rhs_column_name = "RHS", output_colname = "output") } \arguments{ -\item{x}{A \link{rulelist} object} +\item{rulelist}{A \link{rulelist} object} \item{rhs_column_name}{(string, default: "RHS") Name of the column in the rulelist to be used as RHS (WHEN \if{html}{\out{}} THEN {rhs}) in the sql case diff --git a/tests/testthat/test-rulelist.R b/tests/testthat/test-rulelist.R index 6dc5d2c..ec02aa9 100644 --- a/tests/testthat/test-rulelist.R +++ b/tests/testthat/test-rulelist.R @@ -5,18 +5,21 @@ context("test-rulelist") +#### test predict (without setting validation data) ---- + +pen = 
palmerpenguins::penguins model_c5 = C50::C5.0(species ~., - data = palmerpenguins::penguins, + data = pen, trials = 5, rules = TRUE ) tidy_c5 = tidy(model_c5) tidy_c5 -output_1 = predict(tidy_c5, palmerpenguins::penguins) +output_1 = predict(tidy_c5, pen) output_1 # different rules per 'keys' (`trial_nbr` here) -output_2 = predict(tidy_c5, palmerpenguins::penguins, raw = TRUE) +output_2 = predict(tidy_c5, pen, multiple = TRUE) output_2 # `rule_nbr` is a list-column of integer vectors test_that("creates a dataframe", { @@ -24,6 +27,349 @@ test_that("creates a dataframe", { expect_is(output_2, "data.frame") }) -test_that("should not miss any row_nbr", { - expect_true(all(1:nrow(palmerpenguins::penguins) %in% output_1$row_nbr)) -}) \ No newline at end of file +test_that("check output column types", { + expect_true(rlang::is_atomic(output_1$rule_nbr)) + expect_true(rlang::is_list(output_2$rule_nbr)) +}) + +test_that("all row_nbr and keys combo exists", { + all_combos = + tidytable::expand_grid(row_nbr = 1:nrow(pen), + trial_nbr = unique(tidy_c5$trial_nbr) + ) + + expected_to_be_empty_df = + output_1 %>% + distinct(row_nbr, trial_nbr) %>% + data.table::fsetdiff(all_combos) + + expect_true(nrow(expected_to_be_empty_df) == 0) + + expected_to_be_empty_df = + output_1 %>% + distinct(row_nbr, trial_nbr) %>% + data.table::fsetdiff(all_combos) + + expect_true(nrow(expected_to_be_empty_df) == 0) +}) + +#### test set_keys ---- + +test_that("tests for set_keys", { + att = modeldata::attrition + model_c5 = C50::C5.0(Attrition ~ ., data = att, rules = TRUE) + tidy_c5 = tidy(model_c5) + + tidy_c5[["rule_nbr"]] = 1:nrow(tidy_c5) + new_tidy_c5 = set_keys(tidy_c5, NULL) # remove all keys + expect_true(is.null(attr(new_tidy_c5, "keys"))) + + new_2_tidy_c5 = set_keys(new_tidy_c5, "trial_nbr") # set "trial_nbr" as key + expect_true(attr(new_2_tidy_c5, "keys") == "trial_nbr") + + # check for no modification + expect_true(is.null(attr(new_tidy_c5, "keys"))) +}) +#### test 
set_validation_data ---- +test_that("test setting validation data", { + att = modeldata::attrition + set.seed(100) + index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE) + model_c5 = C50::C5.0(Attrition ~., data = att[index, ], rules = TRUE) + + tidy_c5 = tidy(model_c5) + + tidy_c5_2 = set_validation_data(tidy_c5, + validation_data = att[!index, ], + y_name = "Attrition", + weight = 1 # default + ) + expect_false(is.null(attr(tidy_c5_2, "validation_data"))) + expect_true(is.null(attr(tidy_c5, "validation_data"))) # not altered + +}) + +#### test as_rulelist ---- + +test_that("test as_rulelist", { + rules_df = tidytable::tidytable(rule_nbr = 1:2, + LHS = c("var_1 > 50", "var_2 < 30"), + RHS = c(2, 1) + ) + output = as_rulelist(rules_df, estimation_type = "regression") + expect_true(inherits(output, "rulelist")) +}) + +#### test print ---- +test_that("test setting validation data", { + att = modeldata::attrition + model_c5 = C50::C5.0(Attrition ~., data = att, rules = TRUE) + tidy_c5 = tidy(model_c5) + + tidy_c5_2 = set_validation_data(tidy_c5, + validation_data = att, + y_name = "Attrition", + weight = 1 # default + ) + res = print(tidy_c5_2) + expect_equal(res, tidy_c5_2) + + expect_equal(print(tidy_c5), tidy_c5) + +}) + +#### test plot ---- +test_that("test plot", { + library("magrittr") + att = modeldata::attrition + + # classification case + tidy_c5 = + C50::C5.0(Attrition ~., data = att, rules = TRUE) %>% + tidy() %>% + set_validation_data(att, "Attrition") %>% + set_keys(NULL) + + res = plot(tidy_c5) + expect_true(inherits(res, "pheatmap")) + + # regression case + tidy_rpart = + rpart::rpart(MonthlyIncome ~., data = att) %>% + tidy() %>% + set_validation_data(att, "MonthlyIncome") %>% + set_keys(NULL) + + res = plot(tidy_rpart) + expect_true(inherits(res, "pheatmap")) +}) + +#### test augment ---- +test_that("test augment", { + library("magrittr") + + # classification case ---- + att = modeldata::attrition + set.seed(100) + train_index = sample(c(TRUE, 
FALSE), nrow(att), replace = TRUE) + + model_c5 = C50::C5.0(Attrition ~., data = att[train_index, ], rules = TRUE) + tidy_c5 = + model_c5 %>% + tidy() %>% + set_validation_data(att[!train_index, ], "Attrition") + + output = augment(tidy_c5) + output_unnested = tidytable::unnest(output, + augmented_stats, + names_sep = "__" + ) + expect_true(inherits(output, "rulelist")) + expect_true("augmented_stats" %in% colnames(output)) + expect_true(all(c("augmented_stats__support", + "augmented_stats__confidence", + "augmented_stats__lift" + ) %in% colnames(output_unnested) + ) + ) + + # regression case + set.seed(100) + train_index = sample(c(TRUE, FALSE), nrow(iris), replace = TRUE) + + model_regr_rpart = rpart::rpart(Sepal.Length ~ ., data = iris[train_index, ]) + tidy_regr_rpart = tidy(model_regr_rpart) %>% + set_validation_data(iris[!train_index, ], "Sepal.Length") + + output = augment(tidy_regr_rpart) + output_unnested = tidytable::unnest(output, + augmented_stats, + names_sep = "__" + ) + expect_true(inherits(output, "rulelist")) + expect_true("augmented_stats" %in% colnames(output)) + expect_true(all(c("augmented_stats__support", + "augmented_stats__RMSE", + "augmented_stats__IQR" + ) %in% colnames(output_unnested) + ) + ) + + # augment with custom aggregator + output = augment(tidy_c5, output_counts = list(table(Attrition))) + output_unnested = tidytable::unnest(output, + augmented_stats, + names_sep = "__" + ) + expect_true(inherits(output, "rulelist")) + expect_true("augmented_stats" %in% colnames(output)) + expect_true(all(c("augmented_stats__output_counts" + ) %in% colnames(output_unnested) + ) + ) +}) + +#### test calculate ---- + +test_that("test calculate", { # checks the columns returned by calculate() for classification, regression and a user-supplied metric + library("magrittr") + + # classification + model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) + tidy_c5 = tidy(model_c5) %>% + set_validation_data(modeldata::attrition, "Attrition") %>% + set_keys(NULL) + + res = calculate(tidy_c5) + expect_true(inherits(res, "data.frame")) + 
class_metrics = c("cumulative_coverage", + "cumulative_overlap", + "cumulative_accuracy" + ) + expect_true(all(class_metrics %in% colnames(res))) + + # regression + model_rpart = rpart::rpart(MonthlyIncome ~., data = modeldata::attrition) + tidy_rpart = + tidy(model_rpart) %>% + set_validation_data(modeldata::attrition, "MonthlyIncome") %>% + set_keys(NULL) + + res = calculate(tidy_rpart) + expect_true(inherits(res, "data.frame")) + regr_metrics = c("cumulative_coverage", + "cumulative_overlap", + "cumulative_RMSE" + ) + expect_true(all(regr_metrics %in% colnames(res))) + + + # calculate default metrics with a custom metric + #' custom function to get cumulative MAE: returns one value per rule, computed over all rules up to it in rulelist order + library("tidytable") + get_cumulative_MAE = function(rulelist, new_data, y_name, weight){ + + priority_df = + rulelist %>% + select(rule_nbr) %>% + mutate(priority = 1:nrow(rulelist)) %>% + select(rule_nbr, priority) + + pred_df = + predict(rulelist, new_data) %>% + left_join(priority_df, by = "rule_nbr") %>% + mutate(weight = local(weight)) %>% + select(rule_nbr, row_nbr, weight, priority) + + new_data2 = + new_data %>% + mutate(row_nbr = 1:n()) %>% + select(all_of(c("row_nbr", y_name))) + + rmse_till_rule = function(rn){ # NOTE(review): despite the 'rmse' naming, this computes MAE via MetricsWeighted::mae + + if (is.character(rulelist$RHS)) { + inter_df = + pred_df %>% + tidytable::filter(priority <= rn) %>% + left_join(mutate(new_data, row_nbr = 1:n()), by = "row_nbr") %>% + left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") %>% + nest(.by = c("RHS", "rule_nbr", "row_nbr", "priority", "weight")) %>% + mutate(RHS = purrr::map2_dbl(RHS, + data, + ~ eval(parse(text = .x), envir = .y) + ) + ) %>% + unnest(data) + } else { + + inter_df = + pred_df %>% + tidytable::filter(priority <= rn) %>% + left_join(new_data2, by = "row_nbr") %>% + left_join(select(rulelist, rule_nbr, RHS), by = "rule_nbr") + } + + inter_df %>% + summarise(rmse = MetricsWeighted::mae(RHS, + .data[[y_name]], + weight, + na.rm = TRUE + ) + ) %>% + `[[`("rmse") + } + + res = purrr::map_dbl(1:nrow(rulelist), 
rmse_till_rule) + return(res) + } + + res = calculate(tidy_rpart, + metrics_to_exclude = NULL, + list("cumulative_mae" = get_cumulative_MAE) + ) + expect_true(inherits(res, "data.frame")) + custom_regr_metrics = c("cumulative_coverage", + "cumulative_overlap", + "cumulative_RMSE", + "cumulative_mae" + ) + expect_true(all(custom_regr_metrics %in% colnames(res))) +}) + +#### test prune ---- + +test_that("test prune", { # asserts the class of the default prune() result; the two later calls are only printed/plotted + library("magrittr") + model_c5 = C50::C5.0(Attrition ~., data = modeldata::attrition, rules = TRUE) + tidy_c5 = tidy(model_c5) %>% + set_validation_data(modeldata::attrition, "Attrition") %>% + set_keys(NULL) + + #' prune with defaults + prune_obj = prune(tidy_c5) + #' note that all other metrics are visible in the print output + expect_true(inherits(prune_obj, "prune_rulelist")) + plot(prune_obj) + + #' prune with a different stop_expr_string threshold + prune_obj = prune(tidy_c5, + stop_expr_string = "relative__cumulative_coverage >= 0.2" + ) + prune_obj #' as expected, has fewer than 10 rules as compared to default args + plot(prune_obj) + prune_obj$pruned # NOTE(review): value is not checked by any expectation + + #' prune with a different stop_expr_string metric + prune_obj = prune(tidy_c5, + stop_expr_string = "relative__cumulative_overlap <= 0.7 & relative__cumulative_overlap > 0" + ) + prune_obj #' as expected, has fewer than 10 rules as compared to default args + plot(prune_obj) + prune_obj$pruned # NOTE(review): value is not checked by any expectation +}) + +#### test reorder ---- + +test_that("test reorder", { # checks that reorder() returns a rulelist for default, two-metric and init = 2 calls + library("magrittr") + att = modeldata::attrition + tidy_c5 = + C50::C5.0(Attrition ~., data = att, rules = TRUE) %>% + tidy() %>% + set_validation_data(att, "Attrition") %>% + set_keys(NULL) %>% + head(5) + + # with defaults + res = reorder(tidy_c5) + expect_true(inherits(res, "rulelist")) + + # use 'cumulative_overlap' to break ties (if any) + res = reorder(tidy_c5, metric = c("cumulative_coverage", "cumulative_overlap")) + expect_true(inherits(res, "rulelist")) + + # reorder after 2 rules + res = reorder(tidy_c5, init = 2) + 
expect_true(inherits(res, "rulelist")) +})