Increasing efficiency of add_epred_draws for large number of predictors? #314

Open
petermacp opened this issue Aug 23, 2023 · 3 comments

Comments

@petermacp

I am trying to work out a workflow for using add_epred_draws() and subsequent calculations when the number of potential predictors is very large. I appreciate the answer might be "just use a smaller prediction matrix" 🤣, but I'm interested to see what is possible.

Following along with this example from Andrew Heiss, I have constructed a brms regression model that fits well:

library(tidyverse)
library(brms)
library(tidybayes)
library(duckdb)

m3 <- brm(
  bf(choice_alt ~ 0 + (duration + numtabs + reduction +
       passon + adverseeffects + followup + cost + oo) *
       (diseaserisk + age_z + sex + reading + tb_contact) + 
       (1 | ID | pid)),
  data = m_data,
  family = categorical(refcat = "0"),
  prior = c(
    prior(normal(0, 3), class = b, dpar = mu1),
    prior(normal(0, 3), class = b, dpar = mu2),
    prior(normal(0, 3), class = b, dpar = mu3),
    prior(exponential(1), class = sd, dpar = mu1),
    prior(exponential(1), class = sd, dpar = mu2),
    prior(exponential(1), class = sd, dpar = mu3),
    prior(lkj(1), class = cor)
  ),
  chains = 4, cores = 4, iter = 2000, seed = 1234,
  backend = "cmdstanr", threads = threading(2)
)

We make a very large prediction matrix, comprising 5 million rows, with all combinations of predictors (just for this example).

nd3_matrix <-
  expand.grid(
    diseaserisk = unique(m_data$diseaserisk),
    duration = unique(m_data$duration),
    numtabs = unique(m_data$numtabs),
    reduction = unique(m_data$reduction),
    passon = unique(m_data$passon),
    adverseeffects = unique(m_data$adverseeffects),
    followup = unique(m_data$followup),
    cost = unique(m_data$cost),
    oo = unique(m_data$oo),
    sex = unique(m_data$sex),
    reading = unique(m_data$reading),
    tb_contact = unique(m_data$tb_contact),
    age_z = c(-1.5,-1,-0.5,0,0.5,1,1.5,2)
  )

Of course, when we try to add_epred_draws() using nd3_matrix in the newdata= argument, we very quickly run out of memory.
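
A rough sense of why memory runs out can be had from a back-of-envelope calculation. This is only a sketch under stated assumptions: about 5 million newdata rows (as above), 4 outcome categories (the reference category "0" plus mu1, mu2 and mu3), and 8 bytes per double. It counts the .epred column alone, so the real long-format object, which also repeats every predictor column for every draw, would be larger still.

```r
# Back-of-envelope size of the .epred column in long-format output.
# Assumptions (not measured): ~5 million newdata rows, 4 outcome
# categories (refcat "0" plus mu1, mu2, mu3), 8 bytes per double.
n_newdata    <- 5e6
n_categories <- 4
bytes_double <- 8

# Size in GiB of the .epred column for a given number of draws
epred_gb <- function(ndraws) {
  n_newdata * n_categories * ndraws * bytes_double / 1024^3
}

# 4000 assumes all post-warmup draws (4 chains x 1000 with default warmup)
draws <- c(100, 1000, 4000)
data.frame(ndraws = draws, epred_column_gb = round(epred_gb(draws), 1))
```

Under these assumptions, even ndraws = 100 implies roughly 15 GB for the .epred column alone.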

So instead, I wondered if it would be possible to split nd3_matrix into more manageable chunks, and write each chunk's predictions to a database to allow more efficient post-processing, like this:

# Create (or open) an on-disk DuckDB database file
con <- dbConnect(duckdb::duckdb(dbdir = "preds_m3.duckdb"))

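# Write the predictions for one batch (`nest_id`) to its own database table
write_preds <- function(nest_value, ndraws) {

  temp <- nd3_matrix %>%
    # one batch per combination of respondent-level covariates
    group_nest(sex, reading, tb_contact, age_z) %>%
    mutate(nest_id = row_number()) %>%
    # name the list-column explicitly; a bare unnest() is deprecated
    unnest(data) %>%
    ungroup() %>%
    filter(nest_id == nest_value) %>%
    # population-level predictions only (re_formula = NA)
    add_epred_draws(object = m3, re_formula = NA, ndraws = ndraws) %>%
    # keep the reference category and store 1 - P(category 0)
    filter(.category == 0) %>%
    mutate(.epred = 1 - .epred)

  dbWriteTable(con, paste0("processed_data", "_", nest_value), temp, append = TRUE)

}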

# Now run for all groups, and save results to database
nest_ids <- nd3_matrix %>%
  group_nest(sex, reading, tb_contact, age_z) %>%
  mutate(nest_id = row_number()) %>%
  unnest(data) %>%
  ungroup() %>%
  distinct(nest_id) %>%
  pull(nest_id)

walk(nest_ids, .f = \(x) write_preds(nest_value = x, ndraws = 100)) #still only 100 draws... but

# Disconnect from the DuckDB database
dbDisconnect(con, shutdown = TRUE)

This seems to work (in a reasonable hour or so), and allows me to work with the predictions using dbplyr in ways that might not otherwise be possible. But I'm wondering if there are any ways to further optimise this, and potentially increase the ndraws possible?

Thanks!

@mjskay
Owner

mjskay commented Aug 24, 2023

Yeah, we've had to do similar things on large numbers of predictions. It probably would be helpful to have some infrastructure to make this kind of workflow easier, e.g. by providing columns to use to define batches to generate predictions from or something like that.
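
To make the "columns to define batches" idea concrete, here is a minimal base R sketch. The helper name apply_in_batches and its arguments are hypothetical (nothing like this exists in tidybayes as far as this thread says); fn stands in for whatever wraps add_epred_draws() plus any filtering or writing.

```r
# Hypothetical helper: split `newdata` by the columns named in `batch_cols`
# and call `fn` once per batch. In a real workflow `fn` would wrap
# add_epred_draws() and write or summarise the result; here it is a stand-in.
apply_in_batches <- function(newdata, batch_cols, fn) {
  # a list of grouping columns is treated as their interaction
  batches <- split(newdata, newdata[batch_cols], drop = TRUE)
  lapply(batches, fn)
}

# Toy prediction grid: 2 x 3 x 4 = 24 rows
grid <- expand.grid(sex = c("f", "m"), age_z = c(-1, 0, 1), cost = 1:4)

# Stand-in for the per-batch work: just count the rows in each batch
batch_sizes <- apply_in_batches(grid, c("sex", "age_z"), nrow)
```

Because the grid is split once up front, each batch is processed without re-nesting the full grid on every call.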

Re: ndraws specifically, can you make predictions from one row of this model with all draws fine? I guess I'm not clear on exactly what the problem is...

@petermacp
Author

Thanks so much.

I guess my question is whether you have any tips to make the add_epred_draws() call as fast as possible, and output a smaller object?

Peter

@mjskay
Owner

mjskay commented Aug 26, 2023

You could try using add_epred_rvars instead, as the rvar format in a data frame will be more compact than the format output by add_epred_draws.
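
A rough illustration of the size difference, using simulated draws rather than the model from this issue (so the numbers are only indicative). It assumes the posterior package is installed; the long data frame mimics the one-row-per-draw layout, with a made-up predictor x1 repeated for every draw, while the rvar holds all draws for each prediction row in a single element.

```r
library(posterior)

set.seed(1)
n_rows  <- 50    # prediction rows (toy size)
n_draws <- 1000  # posterior draws per prediction row

# Simulated draws standing in for model output:
# one matrix row per draw, one column per prediction row.
draws <- matrix(rnorm(n_draws * n_rows), nrow = n_draws)

# Long format: one data frame row per (prediction row, draw), with the
# predictor value repeated for every draw.
long_df <- data.frame(
  .row   = rep(seq_len(n_rows), each = n_draws),
  x1     = rep(runif(n_rows), each = n_draws),
  .draw  = rep(seq_len(n_draws), times = n_rows),
  .epred = as.vector(draws)
)

# rvar format: a vector of length n_rows, each element holding n_draws draws
pred_rvar <- rvar(draws)

# Sizes in bytes
c(long = object.size(long_df), rvar = object.size(pred_rvar))
```

The gap should widen as more predictor columns are added, since the long format repeats each of them once per draw.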

Beyond that the obvious way to speed this up is to run the computations for different splits in parallel.
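
A minimal sketch of running splits in parallel with the base parallel package. process_batch is a stand-in for the real per-batch work; in practice it would call add_epred_draws() or add_epred_rvars(), and the fitted model and packages would need to be available on each worker. As far as I know, a DuckDB file accepts only one writing process at a time, so workers here return their results instead of writing to a shared connection.

```r
library(parallel)

# Stand-in for the per-batch work; returns a small summary per batch
process_batch <- function(batch) {
  data.frame(batch_id = batch$batch_id[1], n_rows = nrow(batch))
}

# Toy grid with 4 batches of 5 rows each
grid <- data.frame(batch_id = rep(1:4, each = 5), x = seq_len(20))
batches <- split(grid, grid$batch_id)

# Two worker processes; each receives whole batches to process
cl <- makeCluster(2)
results <- parLapply(cl, batches, process_batch)
stopCluster(cl)

# Combine the per-batch results in the main process
combined <- do.call(rbind, results)
```

Memory use scales with the number of workers, so the batch size may need to shrink as the worker count grows.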


2 participants