Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

turn character vectors into factors #54

Merged
merged 5 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# rbi (development version)

* work with character columns in sparse input data

# rbi 1.0.0

* add github actions for package checking and testing
Expand Down
16 changes: 10 additions & 6 deletions R/bi_read.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,

for (coord_dim in names(coord_dims)) {
if (!is.null(x$coord_dims[[coord_dim]]) &&
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
warning(
"Given coord dimension ", coord_dim, " will override a coord ",
"dimension of the same name in passed libbi object"
Expand Down Expand Up @@ -152,7 +152,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,

## cache
if ("libbi" %in% class(x) && x$use_cache &&
(missing(file) || file == "output")) {
(missing(file) || file == "output")) {
if (clear_cache) {
x$.cache$data <- NULL
x$.cache$thin <- NULL
Expand Down Expand Up @@ -295,19 +295,23 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,
rownames(mav) <- seq_len(nrow(mav))

if ("libbi" %in% class(x) && length(x$coord_dims) > 0 &&
var_name %in% names(x$coord_dims) && "coord" %in% colnames(mav)) {
var_name %in% names(x$coord_dims) && "coord" %in% colnames(mav)) {
setnames(mav, "coord", x$coord_dims[[var_name]])
}
if ("libbi" %in% class(x) && length(x$time_dim) == 1 &&
"time" %in% colnames(mav)) {
"time" %in% colnames(mav)) {
setnames(mav, "time", x$time_dim)
}

for (col in colnames(mav)) {
## strip trailing numbers, these indicate duplicate dimensions
dim_col <- sub("\\.[0-9]+$", "", col)
if (!missing(dims) && !is.null(dims) && dim_col %in% names(dims)) {
mav[[col]] <- factor(mav[[col]], labels = dims[[dim_col]])
mav[[col]] <- factor(
mav[[col]] - 1L,
levels = seq_along(dims[[dim_col]]) - 1L,
labels = dims[[dim_col]]
)
} else if (dim_col %in% var_dims[["other"]][[var_name]]) {
mav[[col]] <- mav[[col]] - 1
}
Expand Down Expand Up @@ -339,7 +343,7 @@ bi_read <- function(x, vars, dims, model, type, file, missval_threshold,
if (any(class(x) %in% c("character", "libbi"))) nc_close(nc)

if ("libbi" %in% class(x) && x$use_cache &&
(missing(file) || file == "output")) {
(missing(file) || file == "output")) {
if (is.null(x$.cache[["data"]])) {
x$.cache$data <- list()
}
Expand Down
77 changes: 69 additions & 8 deletions R/bi_write.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#' @return A list of the time and coord dims, and factors in extra dimensions,
#' if any
#' @importFrom ncdf4 nc_close ncdim_def ncvar_def nc_create ncvar_put ncvar_add
#' @importFrom data.table data.table copy
#' @importFrom data.table data.table copy rbindlist
#' @importFrom reshape2 melt
#' @examples
#' filename <- tempfile(pattern = "dummy", fileext = ".nc")
Expand Down Expand Up @@ -80,6 +80,9 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE,
stop("'time_dim' must not be given if guess_time is TRUE")
}

levels <- get_char_levels(variables)
variables <- factorise(variables, levels)

## initialise variables
dims <- list() ## dimension variables created with nc_dim
if (missing(dim_factors)) {
Expand Down Expand Up @@ -242,15 +245,19 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE,
dim_name <- col
## strip trailing numbers, these indicate duplicate dimensions
dim_name <- sub("\\.[0-9]+$", "", dim_name)
dim_values <- unique(element[[col]])
if (is.factor(element[[col]])) {
dim_values <- levels(element[[col]])
} else {
dim_values <- unique(element[[col]])
}
if (dim_name %in% names(dims)) {
if (length(dim_values) != dims[[dim_name]]$len) {
stop(
"Two dimensions of name '", dim_name, "' have different lengths"
)
}
} else {
new_dim <- ncdim_def(dim_name, "", seq_along(unique(dim_values)) - 1)
new_dim <- ncdim_def(dim_name, "", seq_along(dim_values) - 1)
dims[[dim_name]] <- new_dim
if (!(class(dim_values) %in% c("numeric", "integer") &&
length(setdiff(as.integer(dim_values), dim_values)) == 0 &&
Expand Down Expand Up @@ -320,8 +327,8 @@ bi_write <- function(filename, variables, append = FALSE, overwrite = FALSE,

for (name in names(vars)) {
if ((!(append || overwrite)) ||
(append && !(name %in% existing_vars)) ||
(overwrite && (name %in% existing_vars))) {
(append && !(name %in% existing_vars)) ||
(overwrite && (name %in% existing_vars))) {
if (!missing(verbose) && verbose) {
message(date(), " Writing ", name)
}
Expand Down Expand Up @@ -357,6 +364,7 @@ check_sparse_var <- function(x, coord_cols, value_column) {
setorderv(check, coord_cols)

all_values <- lapply(coord_cols, function(x) unique(check[[x]]))
names(all_values) <- coord_cols
all_combinations <- do.call(CJ, all_values)

## check if for all combinations of other calls the values of coord_cols
Expand All @@ -366,7 +374,12 @@ check_sparse_var <- function(x, coord_cols, value_column) {
all(.SD[, coord_cols, with = FALSE] == all_combinations)
), by = other_cols]

return(any(!all[["all_equal"]]))
## all_factors
all_factors <- vapply(coord_cols, function(x) {
length(setdiff(levels(all_values[[x]]), all_values[[x]])) == 0
}, logical(1))

return(any(!all[["all_equal"]]) || any(!all_factors))
}

##' Create a coordinate variable
Expand Down Expand Up @@ -401,8 +414,8 @@ create_coord_var <- function(name, dims, dim_factors, coord_dim, index_table,
for (loop_coord_dim in coord_dim) {
dim_index <- index_table[[loop_coord_dim]]
if (!((is.integer(dim_index) || is.numeric(dim_index)) &&
length(setdiff(as.integer(dim_index), dim_index)) == 0 &&
length(setdiff(seq_len(max(dim_index)), unique(dim_index))) == 0)) {
length(setdiff(as.integer(dim_index), dim_index)) == 0 &&
length(setdiff(seq_len(max(dim_index)), unique(dim_index))) == 0)) {
if (any(class(dim_index) == "factor")) {
dim_factors[[loop_coord_dim]] <- union(
dim_factors[[loop_coord_dim]], levels(dim_index)
Expand Down Expand Up @@ -449,3 +462,51 @@ create_coord_var <- function(name, dims, dim_factors, coord_dim, index_table,
dim = coord_index_dim, dim_factors = dim_factors
))
}

##' Get the factor levels of all character columns in data
##'
##' @param ... variable lists
##' @return a list with elements that represent the factor levels present in
##' character columns
##' @author Sebastian Funk
get_char_levels <- function(...) {
levels <- list()
for (variables in list(...)) {
## convert character strings to factors
data_frames <- names(variables)[
vapply(variables, is.data.frame, logical(1))
]
if (length(data_frames) > 0) {
common <- rbindlist(variables[data_frames], fill = TRUE)
char_cols <- colnames(common)[vapply(common, is.character, logical(1))]
for (col in char_cols) {
levels[[col]] <- union(levels[[col]], unique(na.omit(common[[col]])))
}
}
}
return(levels)
}

##' Convert character columns to factors in data
##'
##' @param levels factor levels, as a named list, each representing one column
##' @inheritParams bi_write
##' @return the \code{variables} argument with factorised columns
##' @author Sebastian Funk
factorise <- function(variables, levels) {
data_frames <- names(variables)[
vapply(variables, is.data.frame, logical(1))
]
if (length(data_frames) > 0) {
for (col in names(levels)) {
## convert character strings to factors
variables[data_frames] <- lapply(variables[data_frames], function(df) {
if (col %in% colnames(df)) {
df[[col]] <- factor(df[[col]], levels = levels[[col]])
}
return(df)
})
}
}
return(variables)
}
23 changes: 14 additions & 9 deletions R/libbi.R
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,18 @@ run.libbi <- function(x, client, proposal = c("model", "prior"), model, fix,
file_args <- intersect(names(args), file_types)
## assign file args to global options
for (arg in file_args) x$options[[arg]] <- get(arg)
list_args <- file_args[vapply(x$options[file_args], is.list, logical(1))]
if (length(list_args) > 0) {
levels <- do.call(get_char_levels, x$options[list_args])
x$options[list_args] <- lapply(x$options[list_args], factorise, levels)
}

if (x$run_flag && length(x$output_file_name) == 1 &&
file.exists(x$output_file_name)) {
file.exists(x$output_file_name)) {
added_options <- option_list(new_options)
init_file_given <-
("init" %in% file_args && !is.null(x$options[["init"]])) ||
"init-file" %in% names(added_options)
"init-file" %in% names(added_options)
init_np_given <- "init-np" %in% names(added_options)
init_given <- init_file_given || init_np_given
if (missing(chain)) { ## if chain not specified, only chain if no init
Expand All @@ -341,7 +346,7 @@ run.libbi <- function(x, client, proposal = c("model", "prior"), model, fix,
)
}
if ("target" %in% names(all_options) &&
all_options[["target"]] == "prediction") {
all_options[["target"]] == "prediction") {
read_init <- bi_read(x, type = c("param", "state", "obs"))
np_dims <- bi_dim_len(x$output_file_name, "np")
x$options[["nsamples"]] <- floor(np_dims / x$thin)
Expand Down Expand Up @@ -796,7 +801,7 @@ attach_data.libbi <- function(x, file, data, in_place = FALSE, append = FALSE,
if (length(coord_dims) > 0) {
for (coord_dim in names(coord_dims)) {
if (!is.null(x$coord_dims[[coord_dim]]) &&
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
any(x$coord_dims[[coord_dim]] != coord_dims[[coord_dim]])) {
warning(
"Given coord dimension ", coord_dim,
" will override a coord dimension of the same name in",
Expand Down Expand Up @@ -834,8 +839,8 @@ attach_data.libbi <- function(x, file, data, in_place = FALSE, append = FALSE,
}

if ((append || overwrite || "list" %in% class(data) ||
file %in% c("obs", "input")) &&
length(vars) > 0) {
file %in% c("obs", "input")) &&
length(vars) > 0) {
write_opts <- list(filename = target_file_name, variables = vars)
if (length(x$time_dim) == 0) {
write_opts[["guess_time"]] <- TRUE
Expand Down Expand Up @@ -1028,7 +1033,7 @@ read_libbi <- function(name, ...) {

for (option in pass_options) {
if (!(option %in% names(libbi_options)) &&
option %in% names(read_obj)) {
option %in% names(read_obj)) {
libbi_options[[option]] <- read_obj[[option]]
}
}
Expand Down Expand Up @@ -1247,7 +1252,7 @@ assert_files.libbi <- function(x, ...) {
stop("The libbi object does not contain an output file.")
} else {
if ("output" %in% names(x$timestamp) &&
x$timestamp[["output"]] < file.mtime(x$output_file_name)) {
x$timestamp[["output"]] < file.mtime(x$output_file_name)) {
stop(
"Output file ", x$output_file_name,
" has been modified since LibBi was run."
Expand All @@ -1259,7 +1264,7 @@ assert_files.libbi <- function(x, ...) {
file_type <- sub("-file$", "", file_option)
if (file.exists(x$options[[file_option]])) {
if (file_type %in% names(x$timestamp) &&
x$timestamp[[file_type]] < file.mtime(x$options[[file_option]])) {
x$timestamp[[file_type]] < file.mtime(x$options[[file_option]])) {
stop(
file_type, " file ", x$options[[file_option]],
" has been modified since LibBi was run. You can use",
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test-io.r
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,34 @@ test_that("saved and re-loaded objects are the same", {
expect_true(all(unlist(values1) - unlist(values2) < 1e-5))
})

test_that("IO works with character strings", {
test_char <- lapply(test_output, function(df) {
if ("a" %in% colnames(df)) df$a <- as.character(df$a)
return(df)
})
filename <- tempfile(fileext = ".nc")
out <- bi_write(filename, test_char, guess_time = TRUE)
test_output2 <- bi_read(
filename, coord_dims = out$coord_dims, dims = out$dims
)
lists <- vapply(test_output, is.list, logical(1))
list_names <- names(lists[lists])
list_cols <- lapply(list_names, function(x) {
setdiff(colnames(test_output[[x]]), "value")
})
names(list_cols) <- list_names

for (name in list_names) {
setorderv(test_output[[name]], list_cols[[name]])
setorderv(test_output2[[name]], list_cols[[name]])
}

values1 <- lapply(test_output[list_names], function(x) x[["value"]])
values2 <- lapply(test_output2[list_names], function(x) x[["value"]])

expect_true(all(unlist(values1) - unlist(values2) < 1e-5))
})

test_that("basic I/O functions work", {
bi <- attach_data(bi, file = "init", test_output[c("e", "m")])
bi <- attach_data(bi, "obs", test_output_sparse[c("M", "e")])
Expand Down
Loading