Skip to content

Commit

Permalink
support matrix columns in nest/unnest_rvars, closes #316
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay committed Sep 3, 2023
1 parent 4b49bbd commit b9f8517
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# tidybayes (development version)

Buf fixes:

* Support for matrix columns in `nest_rvars()` and `unnest_rvars()`. (#316)


# tidybayes 3.0.6

Deprecations:
Expand Down
5 changes: 1 addition & 4 deletions R/nest_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,7 @@ unnest_rvars = function(data) {
# convert from draws_df to plain data.frame to avoid
# warning about meta-data being dropped
class(draws_df) = "data.frame"
# convert from tibble to plain data.frame to fix
# incorrect binding in cbind() in R < 4
class(constants) = "data.frame"
cbind(constants, draws_df)
vctrs::vec_cbind(constants, draws_df)
}))

group_by_at(out, groups_)
Expand Down
31 changes: 25 additions & 6 deletions tests/testthat/test.nest_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ test_that("nest_rvars works", {

draws_df = RankCorr_s %>%
spread_draws(b[i,..], tau[i]) %>%
rename(

) %>%
select(i,
`b[1,1]` = b.1,
`b[1,2]` = b.2,
Expand Down Expand Up @@ -85,9 +82,6 @@ test_that("missing / NA .chain and .iteration columns work", {

draws_df = RankCorr_s %>%
spread_draws(b[i,..], tau[i]) %>%
rename(

) %>%
select(i,
`b[1,1]` = b.1,
`b[1,2]` = b.2,
Expand All @@ -99,3 +93,28 @@ test_that("missing / NA .chain and .iteration columns work", {
expect_equal(unnest_rvars(rvar_df), draws_df)
expect_equal(nest_rvars(draws_df), group_by(rvar_df, i))
})


# matrix columns ----------------------------------------------------------

test_that("matrix columns work", {
rvar_df = tibble::tibble(
x = 1:3,
m = matrix(1:9, 3, 3),
y = rvar(array(1:12, dim = c(4, 3)))
) %>%
group_by(x, m)

draws_df = tibble::tibble(
x = rep(1:3, each = 4),
m = matrix(rep(1:9, each = 4), 12, 3),
y = 1:12,
.chain = 1L,
.iteration = rep(1:4, 3),
.draw = rep(1:4, 3)
) %>%
group_by(x, m)

expect_equal(unnest_rvars(rvar_df), draws_df)
expect_equal(nest_rvars(draws_df), rvar_df)
})

0 comments on commit b9f8517

Please sign in to comment.