Skip to content

Commit

Permalink
turn character vectors into factors (#54)
Browse files Browse the repository at this point in the history
* turn character vectors into factors

* add news item

* indentation

* always consider char variables

* add test
  • Loading branch information
sbfnk authored May 31, 2024
1 parent acab603 commit 0cb0e63
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 23 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# rbi (development version)

* work with character columns in sparse input data

# rbi 1.0.0

* add github actions for package checking and testing
Expand Down
16 changes: 10 additions & 6 deletions R/bi_read.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,

for (coord_dim in names(coord_dims)) {
if (!is.null(x$coord_dims[[coord_dim]]) &&
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
warning(
"Given coord dimension ", coord_dim, " will override a coord ",
"dimension of the same name in passed libbi object"
Expand Down Expand Up @@ -152,7 +152,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,

## cache
if ("libbi" %in% class(x) && x$use_cache &&
(missing(file) || file == "output")) {
(missing(file) || file == "output")) {
if (clear_cache) {
x$.cache$data <- NULL
x$.cache$thin <- NULL
Expand Down Expand Up @@ -295,19 +295,23 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,
rownames(mav) <- seq_len(nrow(mav))

if ("libbi" %in% class(x) && length(x$coord_dims) > 0 &&
var_name %in% names(x$coord_dims) && "coord" %in% colnames(mav)) {
var_name %in% names(x$coord_dims) && "coord" %in% colnames(mav)) {
setnames(mav, "coord", x$coord_dims[[var_name]])
}
if ("libbi" %in% class(x) && length(x$time_dim) == 1 &&
"time" %in% colnames(mav)) {
"time" %in% colnames(mav)) {
setnames(mav, "time", x$time_dim)
}

for (col in colnames(mav)) {
## strip trailing numbers, these indicate duplicate dimensions
dim_col <- sub("\\.[0-9]+$", "", col)
if (!missing(dims) && !is.null(dims) && dim_col %in% names(dims)) {
mav[[col]] <- factor(mav[[col]], labels = dims[[dim_col]])
mav[[col]] <- factor(
mav[[col]] - 1L,
levels = seq_along(dims[[dim_col]]) - 1L,
labels = dims[[dim_col]]
)
} else if (dim_col %in% var_dims[["other"]][[var_name]]) {
mav[[col]] <- mav[[col]] - 1
}
Expand Down Expand Up @@ -339,7 +343,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,
if (any(class(x) %in% c("character", "libbi"))) nc_close(nc)

if ("libbi" %in% class(x) && x$use_cache &&
(missing(file) || file == "output")) {
(missing(file) || file == "output")) {
if (is.null(x$.cache[["data"]])) {
x$.cache$data <- list()
}
Expand Down
77 changes: 69 additions & 8 deletions R/bi_write.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#' @return A list of the time and coord dims, and factors in extra dimensions,
#' if any
#' @importFrom ncdf4 nc_close ncdim_def ncvar_def nc_create ncvar_put ncvar_add
#' @importFrom data.table data.table copy
#' @importFrom data.table data.table copy rbindlist
#' @importFrom reshape2 melt
#' @examples
#' filename <- tempfile(pattern = "dummy", fileext = ".nc")
Expand Down Expand Up @@ -80,6 +80,9 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE,
stop("'time_dim' must not be given if guess_time is TRUE")
}

levels <- get_char_levels(variables)
variables <- factorise(variables, levels)

## initialise variables
dims <- list() ## dimension variables created with nc_dim
if (missing(dim_factors)) {
Expand Down Expand Up @@ -242,15 +245,19 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE,
dim_name <- col
## strip trailing numbers, these indicate duplicate dimensions
dim_name <- sub("\\.[0-9]+$", "", dim_name)
dim_values <- unique(element[[col]])
if (is.factor(element[[col]])) {
dim_values <- levels(element[[col]])
} else {
dim_values <- unique(element[[col]])
}
if (dim_name %in% names(dims)) {
if (length(dim_values) != dims[[dim_name]]$len) {
stop(
"Two dimensions of name '", dim_name, "' have different lengths"
)
}
} else {
new_dim <- ncdim_def(dim_name, "", seq_along(unique(dim_values)) - 1)
new_dim <- ncdim_def(dim_name, "", seq_along(dim_values) - 1)
dims[[dim_name]] <- new_dim
if (!(class(dim_values) %in% c("numeric", "integer") &&
length(setdiff(as.integer(dim_values), dim_values)) == 0 &&
Expand Down Expand Up @@ -320,8 +327,8 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE,

for (name in names(vars)) {
if ((!(append || overwrite)) ||
(append && !(name %in% existing_vars)) ||
(overwrite && (name %in% existing_vars))) {
(append && !(name %in% existing_vars)) ||
(overwrite && (name %in% existing_vars))) {
if (!missing(verbose) && verbose) {
message(date(), " Writing ", name)
}
Expand Down Expand Up @@ -357,6 +364,7 @@ check_sparse_var <- function(x, coord_cols, value_column) {
setorderv(check, coord_cols)

all_values <- lapply(coord_cols, function(x) unique(check[[x]]))
names(all_values) <- coord_cols
all_combinations <- do.call(CJ, all_values)

## check if for all combinations of other calls the values of coord_cols
Expand All @@ -366,7 +374,12 @@ check_sparse_var <- function(x, coord_cols, value_column) {
all(.SD[, coord_cols, with = FALSE] == all_combinations)
), by = other_cols]

return(any(!all[["all_equal"]]))
## all_factors
all_factors <- vapply(coord_cols, function(x) {
length(setdiff(levels(all_values[[x]]), all_values[[x]])) == 0
}, logical(1))

return(any(!all[["all_equal"]]) || any(!all_factors))
}

##' Create a coordinate variable
Expand Down Expand Up @@ -401,8 +414,8 @@ create_coord_var <- function(name, dims, dim_factors, coord_dim, index_table,
for (loop_coord_dim in coord_dim) {
dim_index <- index_table[[loop_coord_dim]]
if (!((is.integer(dim_index) || is.numeric(dim_index)) &&
length(setdiff(as.integer(dim_index), dim_index)) == 0 &&
length(setdiff(seq_len(max(dim_index)), unique(dim_index))) == 0)) {
length(setdiff(as.integer(dim_index), dim_index)) == 0 &&
length(setdiff(seq_len(max(dim_index)), unique(dim_index))) == 0)) {
if (any(class(dim_index) == "factor")) {
dim_factors[[loop_coord_dim]] <- union(
dim_factors[[loop_coord_dim]], levels(dim_index)
Expand Down Expand Up @@ -449,3 +462,51 @@ create_coord_var <- function(name, dims, dim_factors, coord_dim, index_table,
dim = coord_index_dim, dim_factors = dim_factors
))
}

##' Get the factor levels of all character columns in data
##'
##' Each argument is a list of variables. Within each list, the elements that
##' are data frames are row-bound (filling columns that are missing in some of
##' them), and for every column of the combined table that is of type
##' character the distinct non-missing values are collected. Values found for
##' the same column name in several variable lists are merged, keeping the
##' order of first appearance.
##'
##' @param ... variable lists
##' @return a list with elements that represent the factor levels present in
##'   character columns, named by column; an empty list if there are none
##' @author Sebastian Funk
get_char_levels <- function(...) {
  levels <- list()
  for (variables in list(...)) {
    ## select data frames by position rather than by name, so that elements
    ## with missing or duplicated names are neither skipped nor read twice
    is_df <- vapply(variables, is.data.frame, logical(1))
    if (any(is_df)) {
      common <- rbindlist(variables[is_df], fill = TRUE)
      char_cols <- colnames(common)[vapply(common, is.character, logical(1))]
      for (col in char_cols) {
        ## NA is not a level; union() keeps previously found levels first
        levels[[col]] <- union(levels[[col]], unique(na.omit(common[[col]])))
      }
    }
  }
  return(levels)
}

##' Convert character columns to factors in data
##'
##' Every element of \code{variables} that is a data frame is inspected, and
##' each of its columns whose name appears in \code{levels} is converted to a
##' factor with the given levels. Elements that are not data frames, and
##' columns not named in \code{levels}, are returned unchanged.
##'
##' @param levels factor levels, as a named list, each representing one column
##' @inheritParams bi_write
##' @return the \code{variables} argument with factorised columns
##' @author Sebastian Funk
factorise <- function(variables, levels) {
  ## select data frames by position rather than by name: name-based
  ## subsetting only ever reaches the first of several elements sharing a
  ## name, and cannot address unnamed elements at all
  is_df <- vapply(variables, is.data.frame, logical(1))
  if (!any(is_df) || length(levels) == 0) {
    return(variables)
  }
  variables[is_df] <- lapply(variables[is_df], function(df) {
    ## convert character strings to factors
    for (col in intersect(names(levels), colnames(df))) {
      df[[col]] <- factor(df[[col]], levels = levels[[col]])
    }
    df
  })
  return(variables)
}
23 changes: 14 additions & 9 deletions R/libbi.R
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,18 @@ run.libbi <- function(x, client, proposal = c("model", "prior"), model, fix,
file_args <- intersect(names(args), file_types)
## assign file args to global options
for (arg in file_args) x$options[[arg]] <- get(arg)
list_args <- file_args[vapply(x$options[file_args], is.list, logical(1))]
if (length(list_args) > 0) {
levels <- do.call(get_char_levels, x$options[list_args])
x$options[list_args] <- lapply(x$options[list_args], factorise, levels)
}

if (x$run_flag && length(x$output_file_name) == 1 &&
file.exists(x$output_file_name)) {
file.exists(x$output_file_name)) {
added_options <- option_list(new_options)
init_file_given <-
("init" %in% file_args && !is.null(x$options[["init"]])) ||
"init-file" %in% names(added_options)
"init-file" %in% names(added_options)
init_np_given <- "init-np" %in% names(added_options)
init_given <- init_file_given || init_np_given
if (missing(chain)) { ## if chain not specified, only chain if no init
Expand All @@ -341,7 +346,7 @@ run.libbi <- function(x, client, proposal = c("model", "prior"), model, fix,
)
}
if ("target" %in% names(all_options) &&
all_options[["target"]] == "prediction") {
all_options[["target"]] == "prediction") {
read_init <- bi_read(x, type = c("param", "state", "obs"))
np_dims <- bi_dim_len(x$output_file_name, "np")
x$options[["nsamples"]] <- floor(np_dims / x$thin)
Expand Down Expand Up @@ -796,7 +801,7 @@ attach_data.libbi <- function(x, file, data, in_place = FALSE, append = FALSE,
if (length(coord_dims) > 0) {
for (coord_dim in names(coord_dims)) {
if (!is.null(x$coord_dims[[coord_dim]]) &&
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
warning(
"Given coord dimension ", coord_dim,
" will override a coord dimension of the same name in",
Expand Down Expand Up @@ -834,8 +839,8 @@ attach_data.libbi <- function(x, file, data, in_place = FALSE, append = FALSE,
}

if ((append || overwrite || "list" %in% class(data) ||
file %in% c("obs", "input")) &&
length(vars) > 0) {
file %in% c("obs", "input")) &&
length(vars) > 0) {
write_opts <- list(filename = target_file_name, variables = vars)
if (length(x$time_dim) == 0) {
write_opts[["guess_time"]] <- TRUE
Expand Down Expand Up @@ -1028,7 +1033,7 @@ read_libbi <- function(name, ...) {

for (option in pass_options) {
if (!(option %in% names(libbi_options)) &&
option %in% names(read_obj)) {
option %in% names(read_obj)) {
libbi_options[[option]] <- read_obj[[option]]
}
}
Expand Down Expand Up @@ -1247,7 +1252,7 @@ assert_files.libbi <- function(x, ...) {
stop("The libbi object does not contain an output file.")
} else {
if ("output" %in% names(x$timestamp) &&
x$timestamp[["output"]] < file.mtime(x$output_file_name)) {
x$timestamp[["output"]] < file.mtime(x$output_file_name)) {
stop(
"Output file ", x$output_file_name,
" has been modified since LibBi was run."
Expand All @@ -1259,7 +1264,7 @@ assert_files.libbi <- function(x, ...) {
file_type <- sub("-file$", "", file_option)
if (file.exists(x$options[[file_option]])) {
if (file_type %in% names(x$timestamp) &&
x$timestamp[[file_type]] < file.mtime(x$options[[file_option]])) {
x$timestamp[[file_type]] < file.mtime(x$options[[file_option]])) {
stop(
file_type, " file ", x$options[[file_option]],
" has been modified since LibBi was run. You can use",
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test-io.r
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,34 @@ test_that("saved and re-loaded objects are the same", {
expect_true(all(unlist(values1) - unlist(values2) < 1e-5))
})

test_that("IO works with character strings", {
  ## convert column "a" to character strings wherever it is present
  test_char <- lapply(test_output, function(df) {
    if ("a" %in% colnames(df)) df$a <- as.character(df$a)
    return(df)
  })

  ## write the character version and read it back using the returned dims
  filename <- tempfile(fileext = ".nc")
  out <- bi_write(filename, test_char, guess_time = TRUE)
  test_output2 <- bi_read(
    filename, coord_dims = out$coord_dims, dims = out$dims
  )

  ## list-like variables and their index (non-"value") columns
  lists <- vapply(test_output, is.list, logical(1))
  list_names <- names(lists[lists])
  list_cols <- lapply(list_names, function(x) {
    setdiff(colnames(test_output[[x]]), "value")
  })
  names(list_cols) <- list_names

  ## bring both versions into the same row order before comparing values
  for (name in list_names) {
    setorderv(test_output[[name]], list_cols[[name]])
    setorderv(test_output2[[name]], list_cols[[name]])
  }

  values1 <- lapply(test_output[list_names], function(x) x[["value"]])
  values2 <- lapply(test_output2[list_names], function(x) x[["value"]])

  ## compare absolute differences: a one-sided check would accept any
  ## re-read value that is larger than the original
  expect_true(all(abs(unlist(values1) - unlist(values2)) < 1e-5))
})

test_that("basic I/O functions work", {
bi <- attach_data(bi, file = "init", test_output[c("e", "m")])
bi <- attach_data(bi, "obs", test_output_sparse[c("M", "e")])
Expand Down

0 comments on commit 0cb0e63

Please sign in to comment.