Skip to content

Commit

Permalink
Adding SDM comparing union of T and C before and after PS adjustment …
Browse files Browse the repository at this point in the history
…(when estimating ATE). Improving calculation of overall SD.
  • Loading branch information
Admin_mschuemi authored and Admin_mschuemi committed Aug 29, 2023
1 parent e1b07c6 commit 787888d
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 24 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: CohortMethod
Type: Package
Title: New-User Cohort Method with Large Scale Propensity and Outcome Models
Version: 5.1.1
Date: 2023-05-19
Version: 5.2.0
Date: 2023-08-28
Authors@R: c(
person("Martijn", "Schuemie", , "[email protected]", role = c("aut", "cre")),
person("Marc", "Suchard", role = c("aut")),
Expand Down Expand Up @@ -43,7 +43,7 @@ Imports:
cli,
pillar,
Rcpp (>= 0.11.2),
SqlRender (>= 1.7.0),
SqlRender (>= 1.12.0),
survival,
ParallelLogger (>= 3.0.1),
bit64,
Expand Down
9 changes: 8 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
CohortMethod 5.1.1
CohortMethod 5.2.0
==================

Changes:

1. The `computeCovariateBalance()` function now also computes the standardized difference of means comparing the combined target and comparator cohorts before and after PS adjustment, which can inform on generalizability.

2. Improved computation of overall standard deviation when computing covariate balance. Should produce more accurate balance estimations.


Bugfixes:

1. Now passing outcome-specific `riskWindowEnd` argument in `runCmAnalyses()` when specified.
Expand Down
96 changes: 79 additions & 17 deletions R/Balance.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,75 +70,125 @@ computeMeansPerGroup <- function(cohorts, cohortMethodData, covariateFilter) {
mutate(weight = 1 / .data$n) %>%
inner_join(cohorts, by = c("stratumId", "treatment")) %>%
select("rowId", "treatment", "weight")
# Overall weight is for computing mean and SD across T and C
overallW <- stratumSize %>%
group_by(.data$stratumId) %>%
summarise(weight = 1 / sum(.data$n, na.rm = TRUE)) %>%
ungroup() %>%
inner_join(cohorts, by = c("stratumId")) %>%
select("rowId", "weight")
} else {
w <- cohorts %>%
mutate(weight = .data$iptw) %>%
select("rowId", "treatment", "weight")
overallW <- w
}
# Normalize so sum(weight) == 1 per treatment arm:
wSum <- w %>%
group_by(.data$treatment) %>%
summarize(wSum = sum(.data$weight, na.rm = TRUE)) %>%
ungroup()
overallWSum <- overallW %>%
summarize(overallWSum = sum(.data$weight, na.rm = TRUE)) %>%
pull()

cohortMethodData$w <- w %>%
inner_join(wSum, by = "treatment") %>%
mutate(weight = .data$weight / .data$wSum) %>%
select("rowId", "treatment", "weight")

# By definition:
sumW <- 1
cohortMethodData$overallW <- overallW %>%
mutate(overallWeight = .data$weight / overallWSum) %>%
select("rowId", "overallWeight")

# Note: using abs() because, due to rounding at machine precision, the number can become slightly negative:
result <- covariates %>%
inner_join(cohortMethodData$w, by = c("rowId")) %>%
inner_join(cohortMethodData$overallW, by = c("rowId")) %>%
group_by(.data$covariateId, .data$treatment) %>%
summarise(
sum = sum(as.numeric(.data$covariateValue), na.rm = TRUE),
mean = sum(.data$weight * as.numeric(.data$covariateValue), na.rm = TRUE),
overallMean = sum(.data$overallWeight * as.numeric(.data$covariateValue), na.rm = TRUE),
sumSqr = sum(.data$weight * as.numeric(.data$covariateValue)^2, na.rm = TRUE),
sumWSqr = sum(.data$weight^2, na.rm = TRUE)
sumWSqr = sum(.data$weight^2, na.rm = TRUE),
overallSumSqr = sum(.data$overallWeight * as.numeric(.data$covariateValue)^2, na.rm = TRUE),
overallSumWSqr = sum(.data$overallWeight^2, na.rm = TRUE),
.groups = "drop"
) %>%
mutate(sd = sqrt(abs(.data$sumSqr - .data$mean^2) * sumW / (sumW^2 - .data$sumWSqr))) %>%
ungroup() %>%
select("covariateId", "treatment", "sum", "mean", "sd") %>%
select("covariateId", "treatment", "sum", "mean", "sd", "overallMean", "overallSumSqr", "overallSumWSqr") %>%
collect()

cohortMethodData$w <- NULL
cohortMethodData$overallW <- NULL
} else {
# Don't use weighting
cohortCounts <- cohorts %>%
group_by(.data$treatment) %>%
count()
overallCount <- cohorts %>%
count() %>%
pull()

result <- covariates %>%
inner_join(select(cohorts, "rowId", "treatment"), by = "rowId") %>%
group_by(.data$covariateId, .data$treatment) %>%
summarise(
sum = sum(as.numeric(.data$covariateValue), na.rm = TRUE),
sumSqr = sum(as.numeric(.data$covariateValue)^2, na.rm = TRUE)
sumSqr = sum(as.numeric(.data$covariateValue)^2, na.rm = TRUE),
overallSumWSqr = sum(1 / overallCount^2, na.rm = TRUE),
.groups = "drop"
) %>%
inner_join(cohortCounts, by = "treatment") %>%
mutate(
sd = sqrt((.data$sumSqr - (.data$sum^2 / .data$n)) / .data$n),
mean = .data$sum / .data$n
mean = .data$sum / .data$n,
overallMean = .data$sum / overallCount,
overallSumSqr = sumSqr / overallCount
) %>%
ungroup() %>%
select("covariateId", "treatment", "sum", "mean", "sd") %>%
select("covariateId", "treatment", "sum", "mean", "sd", "overallMean", "overallSumSqr", "overallSumWSqr") %>%
collect()
}
target <- result %>%
filter(.data$treatment == 1) %>%
select("covariateId", sumTarget = "sum", meanTarget = "mean", sdTarget = "sd")
select("covariateId",
sumTarget = "sum",
meanTarget = "mean",
sdTarget = "sd",
overallMeanTarget = "overallMean",
overallSumSqrTarget = "overallSumSqr",
overallSumWSqrTarget = "overallSumWSqr")

comparator <- result %>%
filter(.data$treatment == 0) %>%
select("covariateId", sumComparator = "sum", meanComparator = "mean", sdComparator = "sd")
select("covariateId",
sumComparator = "sum",
meanComparator = "mean",
sdComparator = "sd",
overallMeanComparator = "overallMean",
overallSumSqrComparator = "overallSumSqr",
overallSumWSqrComparator = "overallSumWSqr")

# By definition:
sumW <- 1

result <- target %>%
full_join(comparator, by = "covariateId") %>%
mutate(sd = sqrt((.data$sdTarget^2 + .data$sdComparator^2) / 2))

return(result)
mutate(mean = .data$overallMeanTarget + .data$overallMeanComparator,
overallSumSqr = .data$overallSumSqrTarget + .data$overallSumSqrComparator,
overallSumWSqr = .data$overallSumWSqrTarget + .data$overallSumWSqrComparator) %>%
mutate(sd = sqrt(abs(.data$overallSumSqr - .data$mean^2) * sumW / (sumW^2 - .data$overallSumWSqr))) %>%
select(-"overallMeanTarget",
-"overallMeanComparator",
-"overallSumSqrTarget",
-"overallSumSqrComparator",
-"overallSumWSqrTarget",
-"overallSumWSqrComparator")

return(result)
}

#' Compute covariate balance before and after PS adjustment
Expand Down Expand Up @@ -179,13 +229,15 @@ computeMeansPerGroup <- function(cohorts, cohortMethodData, covariateFilter) {
#' - beforeMatchingSumComparator: The (weighted) sum value in the comparator before PS adjustment.
#' - beforeMatchingSdTarget: The standard deviation of the value in the target before PS adjustment.
#' - beforeMatchingSdComparator: The standard deviation of the value in the comparator before PS adjustment.
#' - beforeMatchingMean: The mean of the value across target and comparator before PS adjustment.
#' - beforeMatchingSd: The standard deviation of the value across target and comparator before PS adjustment.
#' - afterMatchingMeanTarget: The (weighted) mean value in the target after PS adjustment.
#' - afterMatchingMeanComparator: The (weighted) mean value in the comparator after PS adjustment.
#' - afterMatchingSumTarget: The (weighted) sum value in the target after PS adjustment.
#' - afterMatchingSumComparator: The (weighted) sum value in the comparator after PS adjustment.
#' - afterMatchingSdTarget: The standard deviation of the value in the target after PS adjustment.
#' - afterMatchingSdComparator: The standard deviation of the value in the comparator after PS adjustment.
#' - afterMatchingMean: The mean of the value across target and comparator after PS adjustment.
#' - afterMatchingSd: The standard deviation of the value across target and comparator after PS adjustment.
#' - beforeMatchingStdDiff: The standardized difference of means when comparing the target to
#' the comparator before PS adjustment.
Expand All @@ -195,14 +247,17 @@ computeMeansPerGroup <- function(cohorts, cohortMethodData, covariateFilter) {
#' before PS adjustment to the target after PS adjustment.
#' - comparatorStdDiff: The standardized difference of means when comparing the comparator
#' before PS adjustment to the comparator after PS adjustment.
#' - targetComparatorStdDiff: The standardized difference of means when comparing the entire
#' population before PS adjustment to the entire population after
#' PS adjustment.
#'
#' The 'beforeMatchingStdDiff' and 'afterMatchingStdDiff' columns inform on the balance:
#' are the target and comparator sufficiently similar in terms of baseline covariates to
#' allow for valid causal estimation?
#'
#' The 'targetStdDiff' and 'comparatorStdDiff' columns inform on the generalizability:
#' are the cohorts after PS adjustment sufficiently similar to the cohorts before adjustment
#' to allow generalizing the findings to the original cohorts?
#' The 'targetStdDiff', 'comparatorStdDiff', and 'targetComparatorStdDiff' columns inform on
#' the generalizability: are the cohorts after PS adjustment sufficiently similar to the cohorts
#' before adjustment to allow generalizing the findings to the original cohorts?
#'
#' @references
#' Austin, P.C. (2008) Assessing balance in measured baseline covariates when using many-to-one
Expand Down Expand Up @@ -292,6 +347,7 @@ computeCovariateBalance <- function(population,
beforeMatchingSumComparator = "sumComparator",
beforeMatchingSdTarget = "sdTarget",
beforeMatchingSdComparator = "sdComparator",
beforeMatchingMean = "mean",
beforeMatchingSd = "sd")
afterMatching <- afterMatching %>%
select("covariateId",
Expand All @@ -301,7 +357,9 @@ computeCovariateBalance <- function(population,
afterMatchingSumComparator = "sumComparator",
afterMatchingSdTarget = "sdTarget",
afterMatchingSdComparator = "sdComparator",
afterMatchingSd = "sd")
afterMatchingMean = "mean",
afterMatchingSd = "sd",
matches("overallMean"))
balance <- beforeMatching %>%
full_join(afterMatching, by = "covariateId") %>%
inner_join(collect(cohortMethodData$covariateRef), by = "covariateId") %>%
Expand Down Expand Up @@ -329,11 +387,15 @@ computeCovariateBalance <- function(population,
.data$beforeMatchingSdComparator == 0,
0,
(.data$beforeMatchingMeanComparator - .data$afterMatchingMeanComparator) / .data$beforeMatchingSdComparator
),
targetComparatorStdDiff = if_else(
.data$beforeMatchingSd == 0,
0,
(.data$beforeMatchingMean - .data$afterMatchingMean) / .data$beforeMatchingSd
)

) %>%
arrange(desc(abs(.data$beforeMatchingStdDiff)))
# TODO: Compute generalizability across T and C

delta <- Sys.time() - start
message(paste("Computing covariate balance took", signif(delta, 3), attr(delta, "units")))
Expand Down
67 changes: 67 additions & 0 deletions extras/TestCovBalance.R
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,70 @@ newXbalance <- function(fmla, strataColumn = NULL , data, report = c("std.diffs"
class(ans) <- c("xbal", "list")
ans
}

# Code for testing generalizability metrics ------------------------------------
# Manual check of the 'across target and comparator' mean and SD columns
# (beforeMatchingMean/Sd, afterMatchingMean/Sd) returned by
# computeCovariateBalance(), by comparing them against the same quantities
# derived in independent ways below.
#
# The two commented-out calls show how the fixture files read below were saved:
# saveRDS(population, "d:/temp/studyPop.rds")
# saveCohortMethodData(cohortMethodData, "d:/temp/cmData.zip")


# Load the saved fixtures.
# NOTE(review): 'population' is presumably the study population after PS
# adjustment (it is used as the 'after' group below) - confirm.
population <- readRDS("d:/temp/studyPop.rds")
cohortMethodData <- loadCohortMethodData("d:/temp/cmData.zip")
cohorts <- cohortMethodData$cohorts %>%
collect()

# Output under test, sorted by covariateId so rows line up with dummyBal below.
bal <- computeCovariateBalance(population, cohortMethodData) %>%
arrange(covariateId)


# Build a dummy two-group cohort table: every row of the full cohorts table is
# labelled treatment = 1, and every row of 'population' is labelled
# treatment = 0.
tPlusCBefore <- cohortMethodData$cohorts %>%
collect() %>%
select("rowId") %>%
mutate(treatment = 1)
tPlusCAfter <- population %>%
# mutate(stratumId = stratumId + treatment * (1 + max(population$stratumId))) %>%
select("rowId", "stratumId") %>%
mutate(treatment = 0)
adjustedCohorts <- bind_rows(tPlusCBefore, tPlusCAfter)
# cohortMethodData$adjustedCohorts <- adjustedCohorts
# adjustedCohorts <- cohortMethodData$adjustedCohorts
# Because of the labelling above, the '...Target' columns of dummyBal describe
# the full cohorts table and the '...Comparator' columns describe 'population'.
dummyBal <- CohortMethod:::computeMeansPerGroup(cohorts = adjustedCohorts, cohortMethodData, NULL) %>%
arrange(covariateId)

# Compute mean before the hard way:
# For a single covariate (covariateId 1007), left-join its values onto all
# cohort rows, treat rows without a value as 0, and take the plain mean and SD.
cohortMethodData$cohorts %>%
left_join(cohortMethodData$covariates %>%
filter(covariateId == 1007),
by = join_by("rowId")) %>%
mutate(covariateValue = if_else(is.na(covariateValue), 0, covariateValue)) %>%
summarise(mean(covariateValue),
sd(covariateValue))
# Using the dummy cov balance:
dummyBal %>%
select(meanTarget, sdTarget) %>%
head(10)
# Compute the mean and SD before matching from the per-arm columns of the
# computeCovariateBalance output:
# Using insight that "The exact pooled variance is the mean of the variances
# plus the variance of the means of the component data sets." from
# https://arxiv.org/ftp/arxiv/papers/1007/1007.1012.pdf
# The per-arm values are weighted by the fraction of cohort rows with
# treatment == 1 (mean(cohorts$treatment)) and its complement.
bal %>%
mutate(meanBefore = beforeMatchingMeanTarget * mean(cohorts$treatment) + beforeMatchingMeanComparator * mean(!cohorts$treatment)) %>%
mutate(beforeVarTarget = beforeMatchingSdTarget^2,
beforeVarComparator = beforeMatchingSdComparator^2) %>%
mutate(meanVar = beforeVarTarget * mean(cohorts$treatment) + beforeVarComparator * mean(!cohorts$treatment),
varOfMeans = (beforeMatchingMeanTarget-meanBefore)^2 * mean(cohorts$treatment) + (beforeMatchingMeanComparator-meanBefore)^2 * mean(!cohorts$treatment)) %>%
mutate(sdBefore = sqrt(meanVar + varOfMeans)) %>%
select(meanBefore, sdBefore) %>%
head(10)
# The values reported directly by computeCovariateBalance, for comparison with
# the two derivations above:
bal %>%
select("beforeMatchingMean", "beforeMatchingSd") %>%
head(10)

# Same for after matching:
# dummyBal's comparator columns (the 'population' rows) versus the
# afterMatching columns reported by computeCovariateBalance.
dummyBal %>%
select(meanComparator, sdComparator) %>%
head(10)
bal %>%
select("afterMatchingMean", "afterMatchingSd") %>%
head(10)


11 changes: 8 additions & 3 deletions man/computeCovariateBalance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 787888d

Please sign in to comment.