diff --git a/NEWS.md b/NEWS.md index 69fb362..fc68c69 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # rbi (development version) +* work with character columns in sparse input data + # rbi 1.0.0 * add github actions for package checking and testing diff --git a/R/bi_read.R b/R/bi_read.R index 8e595f4..447423b 100644 --- a/R/bi_read.R +++ b/R/bi_read.R @@ -84,7 +84,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold, for (coord_dim in names(coord_dims)) { if (!is.null(x$coord_dims[[coord_dim]]) && - any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) { + any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) { warning( "Given coord dimension ", coord_dim, " will override a coord ", "dimension of the same name in passed libbi object" @@ -152,7 +152,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold, ## cache if ("libbi" %in% class(x) && x$use_cache && - (missing(file) || file == "output")) { + (missing(file) || file == "output")) { if (clear_cache) { x$.cache$data <- NULL x$.cache$thin <- NULL @@ -295,11 +295,11 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold, rownames(mav) <- seq_len(nrow(mav)) if ("libbi" %in% class(x) && length(x$coord_dims) > 0 && - var_name %in% names(x$coord_dims) && "coord" %in% colnames(mav)) { + var_name %in% names(x$coord_dims) && "coord" %in% colnames(mav)) { setnames(mav, "coord", x$coord_dims[[var_name]]) } if ("libbi" %in% class(x) && length(x$time_dim) == 1 && - "time" %in% colnames(mav)) { + "time" %in% colnames(mav)) { setnames(mav, "time", x$time_dim) } @@ -307,7 +307,11 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold, ## strip trailing numbers, these indicate duplicate dimensions dim_col <- sub("\\.[0-9]+$", "", col) if (!missing(dims) && !is.null(dims) && dim_col %in% names(dims)) { - mav[[col]] <- factor(mav[[col]], labels = dims[[dim_col]]) + mav[[col]] <- factor( + mav[[col]] - 1L, + levels = 
seq_along(dims[[dim_col]]) - 1L, + labels = dims[[dim_col]] + ) } else if (dim_col %in% var_dims[["other"]][[var_name]]) { mav[[col]] <- mav[[col]] - 1 } @@ -339,7 +343,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold, if (any(class(x) %in% c("character", "libbi"))) nc_close(nc) if ("libbi" %in% class(x) && x$use_cache && - (missing(file) || file == "output")) { + (missing(file) || file == "output")) { if (is.null(x$.cache[["data"]])) { x$.cache$data <- list() } diff --git a/R/bi_write.R b/R/bi_write.R index d2c8105..19fa572 100644 --- a/R/bi_write.R +++ b/R/bi_write.R @@ -43,7 +43,7 @@ #' @return A list of the time and coord dims, and factors in extra dimensions, #' if any #' @importFrom ncdf4 nc_close ncdim_def ncvar_def nc_create ncvar_put ncvar_add -#' @importFrom data.table data.table copy +#' @importFrom data.table data.table copy rbindlist #' @importFrom reshape2 melt #' @examples #' filename <- tempfile(pattern = "dummy", fileext = ".nc") @@ -80,6 +80,9 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE, stop("'time_dim' must not be given if guess_time is TRUE") } + levels <- get_char_levels(variables) + variables <- factorise(variables, levels) + ## initialise variables dims <- list() ## dimension variables created with nc_dim if (missing(dim_factors)) { @@ -242,7 +245,11 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE, dim_name <- col ## strip trailing numbers, these indicate duplicate dimensions dim_name <- sub("\\.[0-9]+$", "", dim_name) - dim_values <- unique(element[[col]]) + if (is.factor(element[[col]])) { + dim_values <- levels(element[[col]]) + } else { + dim_values <- unique(element[[col]]) + } if (dim_name %in% names(dims)) { if (length(dim_values) != dims[[dim_name]]$len) { stop( @@ -250,7 +257,7 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE, ) } } else { - new_dim <- ncdim_def(dim_name, "", seq_along(unique(dim_values)) - 1) 
+ new_dim <- ncdim_def(dim_name, "", seq_along(dim_values) - 1) dims[[dim_name]] <- new_dim if (!(class(dim_values) %in% c("numeric", "integer") && length(setdiff(as.integer(dim_values), dim_values)) == 0 && @@ -320,8 +327,8 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE, for (name in names(vars)) { if ((!(append || overwrite)) || - (append && !(name %in% existing_vars)) || - (overwrite && (name %in% existing_vars))) { + (append && !(name %in% existing_vars)) || + (overwrite && (name %in% existing_vars))) { if (!missing(verbose) && verbose) { message(date(), " Writing ", name) } @@ -357,6 +364,7 @@ check_sparse_var <- function(x, coord_cols, value_column) { setorderv(check, coord_cols) all_values <- lapply(coord_cols, function(x) unique(check[[x]])) + names(all_values) <- coord_cols all_combinations <- do.call(CJ, all_values) ## check if for all combinations of other calls the values of coord_cols @@ -366,7 +374,12 @@ check_sparse_var <- function(x, coord_cols, value_column) { all(.SD[, coord_cols, with = FALSE] == all_combinations) ), by = other_cols] - return(any(!all[["all_equal"]])) + ## all_factors + all_factors <- vapply(coord_cols, function(x) { + length(setdiff(levels(all_values[[x]]), all_values[[x]])) == 0 + }, logical(1)) + + return(any(!all[["all_equal"]]) || any(!all_factors)) } ##' Create a coordinate variable @@ -401,8 +414,8 @@ create_coord_var <- function(name, dims, dim_factors, coord_dim, index_table, for (loop_coord_dim in coord_dim) { dim_index <- index_table[[loop_coord_dim]] if (!((is.integer(dim_index) || is.numeric(dim_index)) && - length(setdiff(as.integer(dim_index), dim_index)) == 0 && - length(setdiff(seq_len(max(dim_index)), unique(dim_index))) == 0)) { + length(setdiff(as.integer(dim_index), dim_index)) == 0 && + length(setdiff(seq_len(max(dim_index)), unique(dim_index))) == 0)) { if (any(class(dim_index) == "factor")) { dim_factors[[loop_coord_dim]] <- union( dim_factors[[loop_coord_dim]], 
##' Get the factor levels of all character columns in data
##'
##' Looks at every data frame contained in the given variable lists (named or
##' unnamed). For each column name that is of type character in at least one
##' of these data frames, the distinct non-missing values of all columns of
##' that name are collected in order of first appearance. Columns of the same
##' name that are not character contribute their values converted to
##' character (factors contribute their levels), so that a later call to
##' \code{factorise} does not turn these values into \code{NA}.
##'
##' @param ... variable lists
##' @return a list with elements that represent the factor levels present in
##'   character columns; an empty list if there are no character columns
##' @author Sebastian Funk
get_char_levels <- function(...) {
  ## collect all data frames by position, so that unnamed list elements are
  ## not silently skipped
  data_frames <- list()
  for (variables in list(...)) {
    for (element in variables) {
      if (is.data.frame(element)) {
        data_frames[[length(data_frames) + 1L]] <- element
      }
    }
  }

  ## names of columns that hold character strings in any of the data frames
  char_cols <- unique(unlist(lapply(data_frames, function(df) {
    colnames(df)[vapply(df, is.character, logical(1))]
  })))

  char_levels <- list()
  for (col in char_cols) {
    for (df in data_frames) {
      if (!(col %in% colnames(df))) next
      values <- df[[col]]
      if (is.factor(values)) {
        ## keep all levels of existing factors, including unused ones
        values <- levels(values)
      } else {
        values <- as.character(values)
      }
      char_levels[[col]] <-
        unique(c(char_levels[[col]], values[!is.na(values)]))
    }
  }
  return(char_levels)
}

##' Convert character columns to factors in data
##'
##' Every data frame in \code{variables} that contains a column named in
##' \code{levels} has that column converted to a factor with the given
##' levels. Elements that are not data frames are returned unchanged.
##'
##' @param levels factor levels, as a named list, each representing one column
##' @inheritParams bi_write
##' @return the \code{variables} argument with factorised columns
##' @author Sebastian Funk
factorise <- function(variables, levels) {
  ## identify data frames by position so that unnamed elements are handled
  is_df <- vapply(variables, is.data.frame, logical(1))
  if (!any(is_df)) {
    return(variables)
  }
  variables[is_df] <- lapply(variables[is_df], function(df) {
    ## convert all relevant columns in a single pass over each data frame
    for (col in intersect(names(levels), colnames(df))) {
      df[[col]] <- factor(df[[col]], levels = levels[[col]])
    }
    return(df)
  })
  return(variables)
}
file_args[vapply(x$options[file_args], is.list, logical(1))] + if (length(list_args) > 0) { + levels <- do.call(get_char_levels, x$options[list_args]) + x$options[list_args] <- lapply(x$options[list_args], factorise, levels) + } if (x$run_flag && length(x$output_file_name) == 1 && - file.exists(x$output_file_name)) { + file.exists(x$output_file_name)) { added_options <- option_list(new_options) init_file_given <- ("init" %in% file_args && !is.null(x$options[["init"]])) || - "init-file" %in% names(added_options) + "init-file" %in% names(added_options) init_np_given <- "init-np" %in% names(added_options) init_given <- init_file_given || init_np_given if (missing(chain)) { ## if chain not specified, only chain if no init @@ -341,7 +346,7 @@ run.libbi <- function(x, client, proposal = c("model", "prior"), model, fix, ) } if ("target" %in% names(all_options) && - all_options[["target"]] == "prediction") { + all_options[["target"]] == "prediction") { read_init <- bi_read(x, type = c("param", "state", "obs")) np_dims <- bi_dim_len(x$output_file_name, "np") x$options[["nsamples"]] <- floor(np_dims / x$thin) @@ -796,7 +801,7 @@ attach_data.libbi <- function(x, file, data, in_place = FALSE, append = FALSE, if (length(coord_dims) > 0) { for (coord_dim in names(coord_dims)) { if (!is.null(x$coord_dims[[coord_dim]]) && - any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) { + any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) { warning( "Given coord dimension ", coord_dim, " will override a coord dimension of the same name in", @@ -834,8 +839,8 @@ attach_data.libbi <- function(x, file, data, in_place = FALSE, append = FALSE, } if ((append || overwrite || "list" %in% class(data) || - file %in% c("obs", "input")) && - length(vars) > 0) { + file %in% c("obs", "input")) && + length(vars) > 0) { write_opts <- list(filename = target_file_name, variables = vars) if (length(x$time_dim) == 0) { write_opts[["guess_time"]] <- TRUE @@ -1028,7 +1033,7 @@ read_libbi <- 
function(name, ...) { for (option in pass_options) { if (!(option %in% names(libbi_options)) && - option %in% names(read_obj)) { + option %in% names(read_obj)) { libbi_options[[option]] <- read_obj[[option]] } } @@ -1247,7 +1252,7 @@ assert_files.libbi <- function(x, ...) { stop("The libbi object does not contain an output file.") } else { if ("output" %in% names(x$timestamp) && - x$timestamp[["output"]] < file.mtime(x$output_file_name)) { + x$timestamp[["output"]] < file.mtime(x$output_file_name)) { stop( "Output file ", x$output_file_name, " has been modified since LibBi was run." @@ -1259,7 +1264,7 @@ assert_files.libbi <- function(x, ...) { file_type <- sub("-file$", "", file_option) if (file.exists(x$options[[file_option]])) { if (file_type %in% names(x$timestamp) && - x$timestamp[[file_type]] < file.mtime(x$options[[file_option]])) { + x$timestamp[[file_type]] < file.mtime(x$options[[file_option]])) { stop( file_type, " file ", x$options[[file_option]], " has been modified since LibBi was run. 
test_that("IO works with character strings", {
  ## convert column "a" to character strings wherever it is present
  test_char <- lapply(test_output, function(df) {
    if ("a" %in% colnames(df)) df$a <- as.character(df$a)
    return(df)
  })
  ## write the character version and read it back using the returned
  ## dimension information
  filename <- tempfile(fileext = ".nc")
  out <- bi_write(filename, test_char, guess_time = TRUE)
  test_output2 <- bi_read(
    filename, coord_dims = out$coord_dims, dims = out$dims
  )
  ## all non-value columns of the list-like elements are used as sort keys
  lists <- vapply(test_output, is.list, logical(1))
  list_names <- names(lists[lists])
  list_cols <- lapply(list_names, function(x) {
    setdiff(colnames(test_output[[x]]), "value")
  })
  names(list_cols) <- list_names

  ## bring original and re-read data into the same row order
  for (name in list_names) {
    setorderv(test_output[[name]], list_cols[[name]])
    setorderv(test_output2[[name]], list_cols[[name]])
  }

  values1 <- lapply(test_output[list_names], function(x) x[["value"]])
  values2 <- lapply(test_output2[list_names], function(x) x[["value"]])

  ## same number of values per variable, so that recycling in the
  ## subtraction below cannot hide a mismatch
  expect_identical(lengths(values1), lengths(values2))
  ## compare absolute differences so that deviations in either direction
  ## make the test fail
  expect_true(all(abs(unlist(values1) - unlist(values2)) < 1e-5))
})