Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subsetting psis objects #110

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ S3method(relative_eff,array)
S3method(relative_eff,default)
S3method(relative_eff,matrix)
S3method(relative_eff,psis)
S3method(subset,psis)
S3method(update,psis_loo_ss)
S3method(waic,"function")
S3method(waic,array)
Expand Down Expand Up @@ -101,6 +102,7 @@ export(psis_n_eff_values)
export(psislw)
export(relative_eff)
export(stacking_weights)
export(subset.psis)
export(waic)
export(waic.array)
export(waic.function)
Expand Down
55 changes: 51 additions & 4 deletions R/psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@
#' uw <- weights(psis_result, log=FALSE, normalize = FALSE) # unnormalized weights
#'
#'
#'
psis <- function(log_ratios, ...) UseMethod("psis")


#' @export
#' @templateVar fn psis
#' @template array
Expand All @@ -102,6 +102,7 @@ psis.array <-
do_psis(log_ratios, r_eff = r_eff, cores = cores)
}


#' @export
#' @templateVar fn psis
#' @template matrix
Expand All @@ -117,6 +118,7 @@ psis.matrix <-
do_psis(log_ratios, r_eff = r_eff, cores = cores)
}


#' @export
#' @templateVar fn psis
#' @template vector
Expand All @@ -129,12 +131,12 @@ psis.default <-
psis.matrix(log_ratios, r_eff = r_eff, cores = 1)
}


#' @rdname psis
#' @export
#' @export weights.psis
#' @method weights psis
#' @param object For the `weights()` method, an object returned by `psis()` (a
#' list with class `"psis"`).
#' @param object,x An object returned by `psis()`.
#' @param log For the `weights()` method, should the weights be returned on
#' the log scale? Defaults to `TRUE`.
#' @param normalize For the `weights()` method, should the weights be
Expand Down Expand Up @@ -165,19 +167,64 @@ weights.psis <-
}


# Subset a psis object without breaking it
#
#' @rdname psis
#' @export
#' @export subset.psis
#' @method subset psis
#' @param subset For the `subset()` method, a vector indicating which
#' observations (columns of weights) to keep. Can be a logical vector of
#' length `ncol(x)` (for a psis object `x`) or a shorter integer vector
#' containing only the indexes to keep.
#'
#' @return The `subset()` returns a `"psis"` object. It is the same as the input
#' but without the contents corresponding to the unselected indexes.
#'
subset.psis <- function(x, subset, ...) {
if (anyNA(subset)) {
stop("NAs not allowed in subset.", call. = FALSE)
}
if (is.logical(subset) || all(subset %in% c(0,1))) {
stopifnot(length(subset) == dim(x)[2])
subset <- which(as.logical(subset))
} else {
stopifnot(length(subset) <= dim(x)[2],
all(subset == as.integer(subset)))
subset <- as.integer(subset)
}

x$log_weights <- x$log_weights[, subset, drop=FALSE]
x$diagnostics$pareto_k <- x$diagnostics$pareto_k[subset]
x$diagnostics$n_eff <- x$diagnostics$n_eff[subset]

structure(
.Data = x,
class = class(x),
dims = c(dim(x)[1], length(subset)),
norm_const_log = attr(x, "norm_const_log")[subset],
tail_len = attr(x, "tail_len")[subset],
r_eff = attr(x, "r_eff")[subset],
subset = subset
)
}


#' @rdname psis
#' @export
dim.psis <- function(x) {
attr(x, "dims")
}


#' @rdname psis
#' @export
#' @param x For `is.psis()`, an object to check.
is.psis <- function(x) {
inherits(x, "psis") && is.list(x)
}



# internal ----------------------------------------------------------------

#' Structure the object returned by the psis methods
Expand Down
18 changes: 14 additions & 4 deletions man/psis.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions tests/testthat/test_psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,34 @@ test_that("weights method returns correct output", {
})


test_that("subset method works correctly", {
dims <- dim(psis1)
a1 <- subset(psis1, subset = rep_len(c(TRUE, FALSE), dims[2]))
a2 <- subset(psis1, subset = seq(1, dims[2], by = 2))
expect_identical(a1, a2) # logical subsetting same as specifying indexes

a3 <- subset(psis1, subset = c(1, 4, 20))
expect_equal(a3$log_weights, psis1$log_weights[, c(1, 4, 20), drop=FALSE])
expect_equal(attr(a3, "tail_len"), attr(psis1, "tail_len")[c(1, 4, 20)])
expect_equal(attr(a3, "subset"), c(1, 4, 20))

expect_error(
subset(psis1, subset = c(TRUE, FALSE)),
"length(subset) == dim(x)[2] is not TRUE",
fixed = TRUE
)
expect_error(
subset(psis1, subset = seq_len(dim(psis1)[2] + 1)),
"length(subset) <= dim(x)[2] is not TRUE",
fixed = TRUE
)
expect_error(
subset(psis1, subset = c(1, NA, 3)),
"NAs not allowed in subset"
)
})


test_that("psis_n_eff methods works properly", {
w <- weights(psis1, normalize = TRUE, log = FALSE)
expect_equal(psis_n_eff.default(w[, 1], r_eff = 1), 1 / sum(w[, 1]^2))
Expand Down