diff --git a/src/fairness.jl b/src/fairness.jl index 90c80ce..a99ec0c 100644 --- a/src/fairness.jl +++ b/src/fairness.jl @@ -67,7 +67,7 @@ module Fairness end end - function _demographic_parity(cohorts, funcs, conn, reference_subjects, process_size, silver) + function _demographic_parity(cohorts::Vector{<:Any}, funcs, conn, reference_subjects, process_size, silver) _funcs = [Fix2(fun, conn) for fun in funcs] @@ -128,7 +128,7 @@ module Fairness return dps end - function _demographic_parity(cohorts, funcs, conn, reference_subjects, process_size) + function _demographic_parity(cohorts::Vector{<:Any}, funcs, conn, reference_subjects, process_size) _funcs = [Fix2(fun, conn) for fun in funcs] @@ -167,6 +167,45 @@ module Fairness return dps end + function _demographic_parity(cohorts::DataFrame, funcs, conn, reference_subjects, process_size) + + _funcs = [Fix2(fun, conn) for fun in funcs] + + if isempty(reference_subjects) + reference_subjects = GetDatabasePersonIDs(conn) + end + + cohorts = cohorts.subject_id + + subsets = _subset_subjects(reference_subjects, process_size) + + denom = DataFrame() + for sub in subsets + denom = vcat(denom, _counter_reducer(sub, :count_denom, _funcs)) + end + + denom = groupby(denom, names(denom)[1:end-1]) |> + x -> combine(x, :count_denom => sum => :count_denom) + + subsets = _subset_subjects(cohorts, process_size) + + num = DataFrame() + for sub in subsets + num = vcat(num, _counter_reducer(sub, :count_num, _funcs)) + end + + num = groupby(num, names(num)[1:end-1]) |> + x -> combine(x, :count_num => sum => :count_num) + + dps = outerjoin(num, denom; on = names(num)[1:end-1] .|> + x -> Symbol(x) => Symbol(x)) |> + x -> coalesce.(x, 0) + + dps.demographic_parity = dps.count_num ./ dps.count_denom + + return dps + end + function equality_of_opportunity(cohorts, funcs, conn; reference_subjects = "", process_size = 10000) _funcs = [Fix2(fun, conn) for fun in funcs]