diff --git a/sksurv/nonparametric.py b/sksurv/nonparametric.py index 82ae7364..9991a8d9 100644 --- a/sksurv/nonparametric.py +++ b/sksurv/nonparametric.py @@ -189,16 +189,29 @@ def _ci_logmlog(prob_survival, sigma_t, z): return ci +def _ci_greenwood(prob_survival, sigma_t, z): + """Compute the pointwise Greenwood confidence intervals + See https://www.math.wustl.edu/~sawyer/handouts/greenwood.pdf. + """ + var_survival_prob = prob_survival**2 * sigma_t**2 + ci = prob_survival + np.array([[-1], [1]]) * z * np.sqrt(var_survival_prob) + return ci + + def _km_ci_estimator(prob_survival, ratio_var, conf_level, conf_type): - if conf_type not in {"log-log"}: - raise ValueError(f"conf_type must be None or a str among {{'log-log'}}, but was {conf_type!r}") + if conf_type not in {"log-log", "greenwood"}: + raise ValueError(f"conf_type must be None or a str among {{'log-log', 'greenwood'}}, but was {conf_type!r}") if not isinstance(conf_level, numbers.Real) or not np.isfinite(conf_level) or conf_level <= 0 or conf_level >= 1.0: raise ValueError(f"conf_level must be a float in the range (0.0, 1.0), but was {conf_level!r}") z = stats.norm.isf((1.0 - conf_level) / 2.0) - sigma = np.sqrt(np.cumsum(ratio_var)) - ci = _ci_logmlog(prob_survival, sigma, z) + sigma_t = np.sqrt(np.cumsum(ratio_var)) + + if conf_type == "log-log": + ci = _ci_logmlog(prob_survival, sigma_t, z) + elif conf_type == "greenwood": + ci = _ci_greenwood(prob_survival, sigma_t, z) return ci @@ -241,11 +254,13 @@ def kaplan_meier_estimator( conf_level : float, optional, default: 0.95 The level for a two-sided confidence interval on the survival curves. - conf_type : None or {'log-log'}, optional, default: 'log-log'. + conf_type : None or {'log-log', 'greenwood'}, optional, default: 'log-log'. The type of confidence intervals to estimate. If `None`, no confidence intervals are estimated. If "log-log", estimate confidence intervals using the log hazard or :math:`log(-log(S(t)))` as described in [2]_. + If "greenwood", estimate confidence intervals using + the Greenwood variance formula. Returns ------- @@ -421,11 +436,13 @@ class SurvivalFunctionEstimator(BaseEstimator): conf_level : float, optional, default: 0.95 The level for a two-sided confidence interval on the survival curves. - conf_type : None or {'log-log'}, optional, default: 'log-log'. + conf_type : None or {'log-log', 'greenwood'}, optional, default: 'log-log'. The type of confidence intervals to estimate. If `None`, no confidence intervals are estimated. If "log-log", estimate confidence intervals using the log hazard or :math:`log(-log(S(t)))`. + If "greenwood", estimate confidence intervals using + the Greenwood variance formula. See also -------- @@ -435,7 +452,7 @@ class SurvivalFunctionEstimator(BaseEstimator): _parameter_constraints = { "conf_level": [Interval(numbers.Real, 0.0, 1.0, closed="neither")], - "conf_type": [None, StrOptions({"log-log"})], + "conf_type": [None, StrOptions({"log-log", "greenwood"})], } def __init__(self, conf_level=0.95, conf_type=None): diff --git a/tests/test_nonparametric.py b/tests/test_nonparametric.py index 84b4cd64..ee29383a 100644 --- a/tests/test_nonparametric.py +++ b/tests/test_nonparametric.py @@ -20,6 +20,9 @@ class SimpleDataKMCases(FixtureParameterFactory): + def __init__(self, conf_type): + self.conf_type = conf_type + @property def time(self): return [1, 2, 2, 3, 7, 6, 5, 5, 3, 9, 11, 13, 17, 13, 6, 23] @@ -32,7 +35,7 @@ def data_all_uncensored(self): time = self.time event = np.repeat(True, len(time)) true_y = np.array([0.9375, 0.8125, 0.6875, 0.5625, 0.4375, 0.375, 0.3125, 0.25, 0.125, 0.0625, 0]) - km_ci = np.array( + km_ci_logmlog = np.array( [ [ 0.632345441738904, @@ -62,6 +65,43 @@ def data_all_uncensored(self): ], ] ) + + km_ci_greenwood = np.array( + [ + [ + 0.81889206, + 0.62125045, + 0.46038309, + 0.31942606, + 0.19442606, + 0.13778413, + 0.08538309, + 0.03782767, + -0.03704929, + -0.05610794, + 0.0, + ], + [ + 1.05610794, + 1.00374955, + 0.91461691, + 0.80557394, + 0.68057394, + 0.61221587, + 0.53961691, + 0.46217233, + 0.28704929, + 0.18110794, + 0.0, + ], + ] + ) + + if self.conf_type == "greenwood": + km_ci = km_ci_greenwood + elif self.conf_type == "log-log": + km_ci = km_ci_logmlog + return time, event, self.true_x, true_y, km_ci def data_all_censored(self): @@ -69,8 +109,15 @@ def data_all_censored(self): event = np.repeat(False, len(time)) true_x = self.true_x true_y = np.ones(true_x.shape[0]) - km_var = np.ones((2, true_x.shape[0])) - return time, event, true_x, true_y, km_var + km_ci_logmlog = np.ones((2, true_x.shape[0])) + km_ci_greenwood = np.ones((2, true_x.shape[0])) + + if self.conf_type == "greenwood": + km_ci = km_ci_greenwood + elif self.conf_type == "log-log": + km_ci = km_ci_logmlog + + return time, event, true_x, true_y, km_ci def data_first_censored(self): time = self.time @@ -91,7 +138,7 @@ def data_first_censored(self): 0, ] ) - km_var = np.array( + km_ci_logmlog = np.array( [ [ 1.0, @@ -122,14 +169,50 @@ def data_first_censored(self): ] ) - return time, event, self.true_x, true_y, km_var + km_ci_greenwood = np.array( + [ + [ + 1.0, + 0.69463917, + 0.50954495, + 0.35208199, + 0.21419932, + 0.15208199, + 0.09477411, + 0.04287828, + -0.03869417, + -0.05956701, + 0.0, + ], + [ + 1.0, + 1.03869417, + 0.95712172, + 0.84791801, + 0.71913401, + 0.64791801, + 0.57189255, + 0.49045505, + 0.30536083, + 0.19290034, + 0.0, + ], + ] + ) + + if self.conf_type == "greenwood": + km_ci = km_ci_greenwood + elif self.conf_type == "log-log": + km_ci = km_ci_logmlog + + return time, event, self.true_x, true_y, km_ci def data_last_censored(self): time = self.time event = np.repeat(True, len(time)) event[-1] = False true_y = np.array([0.9375, 0.8125, 0.6875, 0.5625, 0.4375, 0.375, 0.3125, 0.25, 0.125, 0.0625, 0.0625]) - km_var = np.array( + km_ci_logmlog = np.array( [ [ 0.632345441738904, @@ -159,7 +242,44 @@ def data_last_censored(self): ], ] ) - return time, event, self.true_x, true_y, km_var + + km_ci_greenwood = np.array( + [ + [ + 0.81889206, + 0.62125045, + 0.46038309, + 0.31942606, + 0.19442606, + 0.13778413, + 0.08538309, + 0.03782767, + -0.03704929, + -0.05610794, + -0.05610794, + ], + [ + 1.05610794, + 1.00374955, + 0.91461691, + 0.80557394, + 0.68057394, + 0.61221587, + 0.53961691, + 0.46217233, + 0.28704929, + 0.18110794, + 0.18110794, + ], + ] + ) + + if self.conf_type == "greenwood": + km_ci = km_ci_greenwood + elif self.conf_type == "log-log": + km_ci = km_ci_logmlog + + return time, event, self.true_x, true_y, km_ci def data_first_and_last_censored(self): time = self.time @@ -181,7 +301,7 @@ def data_first_and_last_censored(self): 0.0666666666666667, ] ) - km_var = np.array( + km_ci_logmlog = np.array( [ [ 1.0, @@ -211,7 +331,58 @@ def data_first_and_last_censored(self): ], ] ) - return time, event, self.true_x, true_y, km_var + + km_ci_greenwood = np.array( + [ + [ + 1.0, + 0.69463917, + 0.50954495, + 0.35208199, + 0.21419932, + 0.15208199, + 0.09477411, + 0.04287828, + -0.03869417, + -0.05956701, + -0.05956701, + ], + [ + 1.0, + 1.03869417, + 0.95712172, + 0.84791801, + 0.71913401, + 0.64791801, + 0.57189255, + 0.49045505, + 0.30536083, + 0.19290034, + 0.19290034, + ], + ] + ) + + if self.conf_type == "greenwood": + km_ci = km_ci_greenwood + elif self.conf_type == "log-log": + km_ci = km_ci_logmlog + + return time, event, self.true_x, true_y, km_ci + + def case_data(self, case): + if case == "all_censored": + return self.data_all_censored() + elif case == "all_uncensored": + return self.data_all_uncensored() + elif case == "first_censored": + return self.data_first_censored() + elif case == "last_censored": + return self.data_last_censored() + elif case == "first_and_last_censored": + return self.data_first_and_last_censored() + else: + ValueError("Case is not implemented") class SimpleDataNACases(FixtureParameterFactory): @@ -313,7 +484,7 @@ def data_first_and_last_censored(self): class Whas500CIData(FixtureParameterFactory): - def data_loglog_ci_95(self): + def data_logmlog_ci_95(self): true_ci = np.array( [ [ @@ -1114,7 +1285,808 @@ def data_loglog_ci_95(self): ) return 0.95, "log-log", true_ci - def data_loglog_ci_99(self): + def data_greenwood_ci_95(self): + true_ci = np.array( + [ + [ + 0.97300182, + 0.9525732, + 0.94524122, + 0.94041791, + 0.93563816, + 0.92384684, + 0.9099339, + 0.90305519, + 0.89395131, + 0.88942546, + 0.88716856, + 0.88266622, + 0.8759397, + 0.86924323, + 0.86479446, + 0.86035742, + 0.85814313, + 0.85593156, + 0.85151639, + 0.84491266, + 0.84271628, + 0.84052225, + 0.83833052, + 0.83614106, + 0.83395381, + 0.83176874, + 0.8295858, + 0.82740497, + 0.82304944, + 0.82087469, + 0.81870189, + 0.81653101, + 0.81219491, + 0.80786615, + 0.80570446, + 0.80354451, + 0.80138629, + 0.79922978, + 0.79707494, + 0.79492175, + 0.7927702, + 0.79062025, + 0.7884719, + 0.7863251, + 0.78417986, + 0.78203614, + 0.77989392, + 0.7777532, + 0.77561394, + 0.77347614, + 0.77133977, + 0.76920482, + 0.76707126, + 0.7649391, + 0.7628083, + 0.75855076, + 0.75642398, + 0.75429851, + 0.75217435, + 0.75005146, + 0.74792985, + 0.74369038, + 0.73945584, + 0.73734039, + 0.73522614, + 0.73311308, + 0.7310012, + 0.72889047, + 0.72678091, + 0.7225652, + 0.72045904, + 0.71835399, + 0.71625006, + 0.71414721, + 0.71204546, + 0.70784518, + 0.70574664, + 0.70364916, + 0.70155272, + 0.69945732, + 0.69736295, + 0.6952696, + 0.69317727, + 0.69108595, + 0.68690631, + 0.68481797, + 0.68481797, + 0.68481797, + 0.68481797, + 0.68481797, + 0.68268554, + 0.68055417, + 0.68055417, + 0.68055417, + 0.67840384, + 0.67625459, + 0.67625459, + 0.67625459, + 0.67625459, + 0.67625459, + 0.67406508, + 0.6718767, + 0.6718767, + 0.6718767, + 0.6718767, + 0.6718767, + 0.6718767, + 0.6718767, + 0.66963897, + 0.66963897, + 0.66739503, + 0.66739503, + 0.66739503, + 0.66739503, + 0.66739503, + 0.66739503, + 0.66739503, + 0.66509886, + 0.66509886, + 0.66276416, + 0.66276416, + 0.66276416, + 0.66276416, + 0.66276416, + 0.66276416, + 0.66276416, + 0.66276416, + 0.66033795, + 0.66033795, + 0.65790449, + 0.65547263, + 0.65547263, + 0.65547263, + 0.65302441, + 0.65302441, + 0.65302441, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.65055955, + 0.64796021, + 0.64796021, + 0.64534212, + 0.64271552, + 0.64009095, + 0.64009095, + 0.64009095, + 0.64009095, + 0.63740315, + 0.63740315, + 0.6347063, + 0.63201162, + 0.63201162, + 0.63201162, + 0.63201162, + 0.63201162, + 0.63201162, + 0.63201162, + 0.63201162, + 0.63201162, + 0.62920097, + 0.62639279, + 0.62639279, + 0.62639279, + 0.62356206, + 0.62073384, + 0.61790809, + 0.61508479, + 0.61226392, + 0.61226392, + 0.61226392, + 0.60941962, + 0.60657778, + 0.60657778, + 0.6037252, + 0.60087506, + 0.59802734, + 0.59802734, + 0.59516861, + 0.59231229, + 0.58945836, + 0.58660681, + 0.58375761, + 0.58091073, + 0.57806616, + 0.57522389, + 0.57238388, + 0.56954612, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.5667106, + 0.56361814, + 0.56361814, + 0.56361814, + 0.56361814, + 0.56044298, + 0.56044298, + 0.55725333, + 0.55725333, + 0.55725333, + 0.55725333, + 0.55725333, + 0.55399362, + 0.55399362, + 0.55399362, + 0.55069935, + 0.55069935, + 0.55069935, + 0.55069935, + 0.55069935, + 0.55069935, + 0.55069935, + 0.55069935, + 0.55069935, + 0.54720343, + 0.54720343, + 0.54720343, + 0.54720343, + 0.5436451, + 0.5436451, + 0.5436451, + 0.5436451, + 0.54002176, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53635519, + 0.53219307, + 0.53219307, + 0.53219307, + 0.53219307, + 0.53219307, + 0.53219307, + 0.53219307, + 0.53219307, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52776927, + 0.52278776, + 0.52278776, + 0.52278776, + 0.52278776, + 0.52278776, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.5124889, + 0.50613681, + 0.49981505, + 0.49352238, + 0.48725762, + 0.48101973, + 0.47480774, + 0.46862079, + 0.46245807, + 0.45631884, + 0.45020243, + 0.4441082, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.43803559, + 0.42965799, + 0.42965799, + 0.42965799, + 0.42965799, + 0.42965799, + 0.42965799, + 0.42965799, + 0.42965799, + 0.42965799, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.41970884, + 0.34626109, + 0.34626109, + 0.34626109, + 0.34626109, + 0.34626109, + 0.34626109, + 0.34626109, + 0.34626109, + 0.34626109, + 0.0504821, + -0.08965731, + 0.0, + ], + [ + 0.99499818, + 0.9834268, + 0.97875878, + 0.97558209, + 0.97236184, + 0.96415316, + 0.9540661, + 0.94894481, + 0.94204869, + 0.93857454, + 0.93683144, + 0.93333378, + 0.9280603, + 0.92275677, + 0.91920554, + 0.91564258, + 0.91385687, + 0.91206844, + 0.90848361, + 0.90308734, + 0.90128372, + 0.89947775, + 0.89766948, + 0.89585894, + 0.89404619, + 0.89223126, + 0.8904142, + 0.88859503, + 0.88495056, + 0.88312531, + 0.88129811, + 0.87946899, + 0.87580509, + 0.87213385, + 0.87029554, + 0.86845549, + 0.86661371, + 0.86477022, + 0.86292506, + 0.86107825, + 0.8592298, + 0.85737975, + 0.8555281, + 0.8536749, + 0.85182014, + 0.84996386, + 0.84810608, + 0.8462468, + 0.84438606, + 0.84252386, + 0.84066023, + 0.83879518, + 0.83692874, + 0.8350609, + 0.8331917, + 0.82944924, + 0.82757602, + 0.82570149, + 0.82382565, + 0.82194854, + 0.82007015, + 0.81630962, + 0.81254416, + 0.81065961, + 0.80877386, + 0.80688692, + 0.8049988, + 0.80310953, + 0.80121909, + 0.7974348, + 0.79554096, + 0.79364601, + 0.79174994, + 0.78985279, + 0.78795454, + 0.78415482, + 0.78225336, + 0.78035084, + 0.77844728, + 0.77654268, + 0.77463705, + 0.7727304, + 0.77082273, + 0.76891405, + 0.76509369, + 0.76318203, + 0.76318203, + 0.76318203, + 0.76318203, + 0.76318203, + 0.76123558, + 0.75928808, + 0.75928808, + 0.75928808, + 0.75732458, + 0.75535999, + 0.75535999, + 0.75535999, + 0.75535999, + 0.75535999, + 0.7533635, + 0.75136587, + 0.75136587, + 0.75136587, + 0.75136587, + 0.75136587, + 0.75136587, + 0.75136587, + 0.7493296, + 0.7493296, + 0.74728662, + 0.74728662, + 0.74728662, + 0.74728662, + 0.74728662, + 0.74728662, + 0.74728662, + 0.74520297, + 0.74520297, + 0.74308877, + 0.74308877, + 0.74308877, + 0.74308877, + 0.74308877, + 0.74308877, + 0.74308877, + 0.74308877, + 0.74090563, + 0.74090563, + 0.73871452, + 0.73652181, + 0.73652181, + 0.73652181, + 0.73431453, + 0.73431453, + 0.73431453, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.73209244, + 0.72977131, + 0.72977131, + 0.72743354, + 0.7250863, + 0.72273705, + 0.72273705, + 0.72273705, + 0.72273705, + 0.72033967, + 0.72033967, + 0.71793222, + 0.71552261, + 0.71552261, + 0.71552261, + 0.71552261, + 0.71552261, + 0.71552261, + 0.71552261, + 0.71552261, + 0.71552261, + 0.71302801, + 0.71053093, + 0.71053093, + 0.71053093, + 0.70801397, + 0.7054945, + 0.70297255, + 0.70044815, + 0.69792133, + 0.69792133, + 0.69792133, + 0.69537392, + 0.69282405, + 0.69282405, + 0.69026246, + 0.68769842, + 0.68513197, + 0.68513197, + 0.68255359, + 0.67997279, + 0.6773896, + 0.67480404, + 0.67221612, + 0.66962588, + 0.66703334, + 0.6644385, + 0.66184139, + 0.65924203, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.65664043, + 0.6538514, + 0.6538514, + 0.6538514, + 0.6538514, + 0.65099948, + 0.65099948, + 0.64813192, + 0.64813192, + 0.64813192, + 0.64813192, + 0.64813192, + 0.64521017, + 0.64521017, + 0.64521017, + 0.64225858, + 0.64225858, + 0.64225858, + 0.64225858, + 0.64225858, + 0.64225858, + 0.64225858, + 0.64225858, + 0.64225858, + 0.63916358, + 0.63916358, + 0.63916358, + 0.63916358, + 0.63601927, + 0.63601927, + 0.63601927, + 0.63601927, + 0.63282374, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62959121, + 0.62603184, + 0.62603184, + 0.62603184, + 0.62603184, + 0.62603184, + 0.62603184, + 0.62603184, + 0.62603184, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.62229913, + 0.61822498, + 0.61822498, + 0.61822498, + 0.61822498, + 0.61822498, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60981871, + 0.60494773, + 0.60004641, + 0.59511601, + 0.59015769, + 0.58517251, + 0.58016141, + 0.57512529, + 0.57006493, + 0.56498109, + 0.55987442, + 0.55474557, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54959511, + 0.54323195, + 0.54323195, + 0.54323195, + 0.54323195, + 0.54323195, + 0.54323195, + 0.54323195, + 0.54323195, + 0.54323195, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.53611286, + 0.5299088, + 0.5299088, + 0.5299088, + 0.5299088, + 0.5299088, + 0.5299088, + 0.5299088, + 0.5299088, + 0.5299088, + 0.53363116, + 0.38171394, + 0.0, + ], + ] + ) + return 0.95, "greenwood", true_ci + + def data_logmlog_ci_99(self): true_ci = np.array( [ [ @@ -1915,6 +2887,807 @@ def data_loglog_ci_99(self): ) return 0.99, "log-log", true_ci + def data_greenwood_ci_99(self): + true_ci = np.array( + [ + [ + 0.96954595, + 0.94772575, + 0.93997523, + 0.93489322, + 0.92986845, + 0.91751426, + 0.90300023, + 0.89584541, + 0.88639467, + 0.88170358, + 0.87936596, + 0.87470577, + 0.86775096, + 0.86083564, + 0.85624586, + 0.85167149, + 0.84938986, + 0.84711182, + 0.84256619, + 0.83577275, + 0.83351466, + 0.83125966, + 0.8290077, + 0.8267587, + 0.82451261, + 0.82226939, + 0.82002897, + 0.81779131, + 0.81332407, + 0.8110944, + 0.8088673, + 0.80664274, + 0.80220103, + 0.79776897, + 0.79555646, + 0.79334626, + 0.79113832, + 0.78893263, + 0.78672913, + 0.78452782, + 0.78232864, + 0.78013158, + 0.7779366, + 0.77574368, + 0.7735528, + 0.77136392, + 0.76917702, + 0.76699207, + 0.76480906, + 0.76262795, + 0.76044873, + 0.75827138, + 0.75609586, + 0.75392216, + 0.75175027, + 0.7474118, + 0.74524519, + 0.7430803, + 0.74091711, + 0.73875561, + 0.73659578, + 0.73228107, + 0.72797284, + 0.72582111, + 0.72367096, + 0.72152237, + 0.71937533, + 0.71722982, + 0.71508582, + 0.71080233, + 0.70866282, + 0.70652476, + 0.70438816, + 0.70225301, + 0.70011928, + 0.69585607, + 0.69372657, + 0.69159845, + 0.68947171, + 0.68734633, + 0.68522231, + 0.68309963, + 0.68097829, + 0.67885827, + 0.67462218, + 0.67250609, + 0.67250609, + 0.67250609, + 0.67250609, + 0.67250609, + 0.67034444, + 0.66818418, + 0.66818418, + 0.66818418, + 0.6660045, + 0.66382623, + 0.66382623, + 0.66382623, + 0.66382623, + 0.66382623, + 0.66160639, + 0.65938805, + 0.65938805, + 0.65938805, + 0.65938805, + 0.65938805, + 0.65938805, + 0.65938805, + 0.65711866, + 0.65711866, + 0.65484315, + 0.65484315, + 0.65484315, + 0.65484315, + 0.65484315, + 0.65484315, + 0.65484315, + 0.65251359, + 0.65251359, + 0.65014425, + 0.65014425, + 0.65014425, + 0.65014425, + 0.65014425, + 0.65014425, + 0.65014425, + 0.65014425, + 0.64767985, + 0.64767985, + 0.64520832, + 0.64273888, + 0.64273888, + 0.64273888, + 0.64025281, + 0.64025281, + 0.64025281, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.6377498, + 0.63510675, + 0.63510675, + 0.63244462, + 0.62977413, + 0.62710631, + 0.62710631, + 0.62710631, + 0.62710631, + 0.62437288, + 0.62437288, + 0.62163056, + 0.6188911, + 0.6188911, + 0.6188911, + 0.6188911, + 0.6188911, + 0.6188911, + 0.6188911, + 0.6188911, + 0.6188911, + 0.61603078, + 0.61317373, + 0.61317373, + 0.61317373, + 0.61029371, + 0.60741697, + 0.60454349, + 0.60167324, + 0.59880617, + 0.59880617, + 0.59880617, + 0.59591523, + 0.59302751, + 0.59302751, + 0.59012921, + 0.58723412, + 0.58434221, + 0.58434221, + 0.58143943, + 0.57853983, + 0.57564337, + 0.57275002, + 0.56985976, + 0.56697257, + 0.56408841, + 0.56120726, + 0.55832909, + 0.55545388, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.5525816, + 0.54944146, + 0.54944146, + 0.54944146, + 0.54944146, + 0.54621552, + 0.54621552, + 0.54297527, + 0.54297527, + 0.54297527, + 0.54297527, + 0.54297527, + 0.53966246, + 0.53966246, + 0.53966246, + 0.53631436, + 0.53631436, + 0.53631436, + 0.53631436, + 0.53631436, + 0.53631436, + 0.53631436, + 0.53631436, + 0.53631436, + 0.53275544, + 0.53275544, + 0.53275544, + 0.53275544, + 0.52913206, + 0.52913206, + 0.52913206, + 0.52913206, + 0.52544151, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.52170675, + 0.51744993, + 0.51744993, + 0.51744993, + 0.51744993, + 0.51744993, + 0.51744993, + 0.51744993, + 0.51744993, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.51291755, + 0.50779348, + 0.50779348, + 0.50779348, + 0.50779348, + 0.50779348, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49719728, + 0.49061248, + 0.48406757, + 0.47756086, + 0.47109084, + 0.46465614, + 0.45825548, + 0.45188772, + 0.44555181, + 0.43924677, + 0.43297171, + 0.42672581, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.42050832, + 0.41181423, + 0.41181423, + 0.41181423, + 0.41181423, + 0.41181423, + 0.41181423, + 0.41181423, + 0.41181423, + 0.41181423, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.40142044, + 0.31740794, + 0.31740794, + 0.31740794, + 0.31740794, + 0.31740794, + 0.31740794, + 0.31740794, + 0.31740794, + 0.31740794, + -0.02542612, + -0.1637151, + 0.0, + ], + [ + 0.99845405, + 0.98827425, + 0.98402477, + 0.98110678, + 0.97813155, + 0.97048574, + 0.96099977, + 0.95615459, + 0.94960533, + 0.94629642, + 0.94463404, + 0.94129423, + 0.93624904, + 0.93116436, + 0.92775414, + 0.92432851, + 0.92261014, + 0.92088818, + 0.91743381, + 0.91222725, + 0.91048534, + 0.90874034, + 0.9069923, + 0.9052413, + 0.90348739, + 0.90173061, + 0.89997103, + 0.89820869, + 0.89467593, + 0.8929056, + 0.8911327, + 0.88935726, + 0.88579897, + 0.88223103, + 0.88044354, + 0.87865374, + 0.87686168, + 0.87506737, + 0.87327087, + 0.87147218, + 0.86967136, + 0.86786842, + 0.8660634, + 0.86425632, + 0.8624472, + 0.86063608, + 0.85882298, + 0.85700793, + 0.85519094, + 0.85337205, + 0.85155127, + 0.84972862, + 0.84790414, + 0.84607784, + 0.84424973, + 0.8405882, + 0.83875481, + 0.8369197, + 0.83508289, + 0.83324439, + 0.83140422, + 0.82771893, + 0.82402716, + 0.82217889, + 0.82032904, + 0.81847763, + 0.81662467, + 0.81477018, + 0.81291418, + 0.80919767, + 0.80733718, + 0.80547524, + 0.80361184, + 0.80174699, + 0.79988072, + 0.79614393, + 0.79427343, + 0.79240155, + 0.79052829, + 0.78865367, + 0.78677769, + 0.78490037, + 0.78302171, + 0.78114173, + 0.77737782, + 0.77549391, + 0.77549391, + 0.77549391, + 0.77549391, + 0.77549391, + 0.77357669, + 0.77165808, + 0.77165808, + 0.77165808, + 0.76972392, + 0.76778835, + 0.76778835, + 0.76778835, + 0.76778835, + 0.76778835, + 0.76582218, + 0.76385452, + 0.76385452, + 0.76385452, + 0.76385452, + 0.76385452, + 0.76385452, + 0.76385452, + 0.76184991, + 0.76184991, + 0.7598385, + 0.7598385, + 0.7598385, + 0.7598385, + 0.7598385, + 0.7598385, + 0.7598385, + 0.75778824, + 0.75778824, + 0.75570868, + 0.75570868, + 0.75570868, + 0.75570868, + 0.75570868, + 0.75570868, + 0.75570868, + 0.75570868, + 0.75356373, + 0.75356373, + 0.7514107, + 0.74925556, + 0.74925556, + 0.74925556, + 0.74708614, + 0.74708614, + 0.74708614, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74490218, + 0.74262477, + 0.74262477, + 0.74033103, + 0.73802769, + 0.7357217, + 0.7357217, + 0.7357217, + 0.7357217, + 0.73336994, + 0.73336994, + 0.73100796, + 0.72864313, + 0.72864313, + 0.72864313, + 0.72864313, + 0.72864313, + 0.72864313, + 0.72864313, + 0.72864313, + 0.72864313, + 0.72619819, + 0.72374999, + 0.72374999, + 0.72374999, + 0.72128232, + 0.71881136, + 0.71633714, + 0.71385971, + 0.71137908, + 0.71137908, + 0.71137908, + 0.70887831, + 0.70637432, + 0.70637432, + 0.70385845, + 0.70133936, + 0.6988171, + 0.6988171, + 0.69628276, + 0.69374525, + 0.6912046, + 0.68866082, + 0.68611397, + 0.68356405, + 0.68101109, + 0.67845513, + 0.67589618, + 0.67333428, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.67076943, + 0.66802808, + 0.66802808, + 0.66802808, + 0.66802808, + 0.66522694, + 0.66522694, + 0.66240998, + 0.66240998, + 0.66240998, + 0.66240998, + 0.66240998, + 0.65954133, + 0.65954133, + 0.65954133, + 0.65664358, + 0.65664358, + 0.65664358, + 0.65664358, + 0.65664358, + 0.65664358, + 0.65664358, + 0.65664358, + 0.65664358, + 0.65361157, + 0.65361157, + 0.65361157, + 0.65361157, + 0.6505323, + 0.6505323, + 0.6505323, + 0.6505323, + 0.64740399, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64423965, + 0.64077498, + 0.64077498, + 0.64077498, + 0.64077498, + 0.64077498, + 0.64077498, + 0.64077498, + 0.64077498, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63715085, + 0.63321926, + 0.63321926, + 0.63321926, + 0.63321926, + 0.63321926, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62511033, + 0.62047205, + 0.61579389, + 0.61107752, + 0.60632446, + 0.60153609, + 0.59671368, + 0.59185836, + 0.5869712, + 0.58205316, + 0.57710514, + 0.57212797, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56712238, + 0.56107571, + 0.56107571, + 0.56107571, + 0.56107571, + 0.56107571, + 0.56107571, + 0.56107571, + 0.56107571, + 0.56107571, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55440125, + 0.55876195, + 0.55876195, + 0.55876195, + 0.55876195, + 0.55876195, + 0.55876195, + 0.55876195, + 0.55876195, + 0.55876195, + 0.60953938, + 0.45577173, + 0.0, + ], + ] + ) + return 0.99, "greenwood", true_ci + @pytest.fixture() def make_channing(): @@ -2817,30 +4590,34 @@ def whas500_true_x(): class TestKaplanMeier: @staticmethod - @pytest.mark.parametrize("time,event,true_x,true_y,km_ci", SimpleDataKMCases().get_cases()) - def test_simple(time, event, true_x, true_y, km_ci): - x, y = kaplan_meier_estimator(event, time) - - assert_array_equal(x, true_x) - assert_array_almost_equal(y, true_y) + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + @pytest.mark.parametrize( + "case", ["all_censored", "all_uncensored", "first_censored", "last_censored", "first_and_last_censored"] + ) + def test_simple(case, conf_type): + time, event, true_x, true_y, km_ci_true = SimpleDataKMCases(conf_type=conf_type).case_data(case=case) + x, y, km_ci = kaplan_meier_estimator(event, time, conf_type=conf_type) - x, y, ci = kaplan_meier_estimator(event, time, conf_type="log-log") assert_array_equal(x, true_x) assert_array_almost_equal(y, true_y) - assert_array_almost_equal(ci, km_ci) + assert_array_almost_equal(km_ci, km_ci_true) ys = Surv.from_arrays(event, time) - est = SurvivalFunctionEstimator(conf_type="log-log").fit(ys) + est = SurvivalFunctionEstimator(conf_type=conf_type).fit(ys) assert_array_equal(est.unique_time_[1:], true_x) assert_array_almost_equal(est.prob_[1:], true_y) - assert_array_almost_equal(est.conf_int_[:, 1:], km_ci) + assert_array_almost_equal(est.conf_int_[:, 1:], km_ci_true) prob, ci = est.predict_proba(true_x, return_conf_int=True) assert_array_almost_equal(prob, true_y) - assert_array_almost_equal(ci, km_ci) + assert_array_almost_equal(ci, km_ci_true) @staticmethod - @pytest.mark.parametrize("time,event,true_x,true_y,km_var", SimpleDataKMCases().get_cases()) - def test_wrong_dtype(time, event, true_x, true_y, km_var): # noqa: F821 undefined name + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + @pytest.mark.parametrize( + "case", ["all_censored", "all_uncensored", "first_censored", "last_censored", "first_and_last_censored"] + ) + def test_wrong_dtype(case, conf_type): # noqa: F821 undefined name + time, event, true_x, true_y, km_ci_true = SimpleDataKMCases(conf_type=conf_type).case_data(case=case) ys = Surv.from_arrays(event, time) est = SurvivalFunctionEstimator().fit(ys) with pytest.raises(ValueError, match="dtype='numeric' is not compatible with arrays of bytes/strings"): @@ -2851,21 +4628,23 @@ def test_wrong_dtype(time, event, true_x, true_y, km_var): # noqa: F821 undefin @staticmethod @pytest.mark.parametrize("conf_level", [None, -1, 1.0, 3.0, np.inf, np.nan]) - def test_estimator_invalid_conf_level(random_survival_data, conf_level): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_estimator_invalid_conf_level(random_survival_data, conf_level, conf_type): msg = r"The 'conf_level' parameter of SurvivalFunctionEstimator must be a float in the range \(0\.0, 1\.0\)\." with pytest.raises(ValueError, match=msg): - SurvivalFunctionEstimator(conf_level=conf_level, conf_type="log-log").fit(random_survival_data) + SurvivalFunctionEstimator(conf_level=conf_level, conf_type=conf_type).fit(random_survival_data) @staticmethod @pytest.mark.parametrize("conf_level", [None, -1, 1.0, 3.0, np.inf, np.nan]) - def test_invalid_conf_level(random_survival_data, conf_level): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_invalid_conf_level(random_survival_data, conf_level, conf_type): msg = r"conf_level must be a float in the range \(0\.0, 1\.0\)" with pytest.raises(ValueError, match=msg): kaplan_meier_estimator( random_survival_data["event"], random_survival_data["time"], conf_level=conf_level, - conf_type="log-log", + conf_type=conf_type, ) @staticmethod @@ -3322,12 +5101,13 @@ def test_ci_whas500(make_whas500, conf_level, conf_type, true_ci): assert_array_almost_equal(ci, true_ci) @staticmethod - def test_left_truncated_simple1(): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_left_truncated_simple1(conf_type): time_enter = np.array([4, 3, 4, 2, 1, 5, 6, 7]) time_exit = np.array([6, 7, 8, 10, 9, 9, 10, 11]) event = np.array([0, 1, 0, 0, 1, 1, 1, 0], dtype=bool) - x, y, km_ci = kaplan_meier_estimator(event, time_exit, time_enter, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event, time_exit, time_enter, conf_type=conf_type) true_x = np.arange(1, 12) assert_array_almost_equal(x, true_x) @@ -3335,21 +5115,32 @@ def test_left_truncated_simple1(): true_y = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.833333, 0.833333, 0.5, 0.333333, 0.333333]) assert_array_almost_equal(y, true_y) - true_ci = np.array( + true_ci_logmlog = np.array( [ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.273123, 0.273123, 0.110948, 0.046082, 0.046082], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.974712, 0.974712, 0.803709, 0.675564, 0.675564], ] ) - assert_array_almost_equal(km_ci, true_ci) + true_ci_greenwood = np.array( + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.53513431, 0.53513431, 0.09992403, -0.04386191, -0.04386191], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.13153236, 1.13153236, 0.90007597, 0.71052858, 0.71052858], + ] + ) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci, true_ci_logmlog) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci, true_ci_greenwood) @staticmethod - def test_left_truncated_simple2(): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_left_truncated_simple2(conf_type): time_enter = np.array([1, 4, 3, 4, 2, 1, 5, 6, 7]) time_exit = np.array([4, 6, 7, 8, 10, 9, 9, 10, 11]) event = np.array([1, 0, 1, 0, 0, 1, 1, 1, 0], dtype=bool) - x, y, km_ci = kaplan_meier_estimator(event, time_exit, time_enter, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event, time_exit, time_enter, conf_type=conf_type) true_x = np.arange(1, 12) assert_array_almost_equal(x, true_x) @@ -3357,21 +5148,56 @@ def test_left_truncated_simple2(): true_y = np.array([1.0, 1.0, 1.0, 0.75, 0.75, 0.75, 0.625, 0.625, 0.375, 0.25, 0.25]) assert_array_almost_equal(y, true_y) - true_ci = np.array( + true_ci_logmlog = np.array( [ [1.0, 1.0, 1.0, 0.127947, 0.127947, 0.127947, 0.141853, 0.141853, 0.069678, 0.03165, 0.03165], [1.0, 1.0, 1.0, 0.960549, 0.960549, 0.960549, 0.893051, 0.893051, 0.696882, 0.573177, 0.573177], ] ) - assert_array_almost_equal(km_ci, true_ci) + true_ci_greenwood = np.array( + [ + [ + 1.0, + 1.0, + 1.0, + 0.32565535, + 0.32565535, + 0.32565535, + 0.20659053, + 0.20659053, + 0.00750675, + -0.06628783, + -0.06628783, + ], + [ + 1.0, + 1.0, + 1.0, + 1.17434465, + 1.17434465, + 1.17434465, + 1.04340947, + 1.04340947, + 0.74249325, + 0.56628783, + 0.56628783, + ], + ] + ) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci, true_ci_logmlog) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci, true_ci_greenwood) @staticmethod - def test_left_truncated_simple3(): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_left_truncated_simple3(conf_type): time_enter = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 9]) time_exit = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 19]) event = np.array([0, 1, 0, 0, 1, 1, 1, 1, 0, 1], dtype=bool) - x, y, km_ci = kaplan_meier_estimator(event, time_exit, time_enter, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event, time_exit, time_enter, conf_type=conf_type) true_x = np.concatenate((np.arange(1, 10), np.arange(11, 20))) assert_array_almost_equal(x, true_x) @@ -3379,49 +5205,66 @@ def test_left_truncated_simple3(): true_y = np.array([1, 0.888889, 0.888889, 0.888889, 0.740741, 0.592593, 0.444444, 0.296296, 0.148148]) assert_array_almost_equal(y[9:], true_y) - true_ci = np.array( + true_ci_logmlog = np.array( [ [1.0, 0.432965, 0.432965, 0.432965, 0.289212, 0.185919, 0.103865, 0.043129, 0.00736], [1.0, 0.983564, 0.983564, 0.983564, 0.929976, 0.849818, 0.747981, 0.624575, 0.475977], ] ) - assert_array_almost_equal(km_ci[:, :9], np.ones((2, 9))) - assert_array_almost_equal(km_ci[:, 9:], true_ci) + + true_ci_greenwood = np.array( + [ + [1.0, 0.6835698, 0.6835698, 0.6835698, 0.42524934, 0.23044377, 0.0743002, -0.04590218, -0.1191174], + [1.0, 1.09420797, 1.09420797, 1.09420797, 1.05623215, 0.95474142, 0.81458869, 0.63849477, 0.4154137], + ] + ) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci[:, :9], np.ones((2, 9))) + assert_array_almost_equal(km_ci[:, 9:], true_ci_logmlog) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci[:, :9], np.ones((2, 9))) + assert_array_almost_equal(km_ci[:, 9:], true_ci_greenwood) @staticmethod - def test_truncated_male(make_channing, channing_male_true_x): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_truncated_male(make_channing, channing_male_true_x, conf_type): time_enter_m, time_exit_m, event_m = make_channing("Male") - x, y, km_ci = kaplan_meier_estimator(event_m, time_exit_m, time_enter_m, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event_m, time_exit_m, time_enter_m, conf_type=conf_type) assert_array_equal(x, channing_male_true_x) assert_array_equal(y[:3], np.array([1.0, 1.0, 0.5])) assert (y[3:] == 0).all() - assert_array_almost_equal( - km_ci[:, :3], - np.array( + true_ci_logmlog = np.array( + [ + [ + 1.0, + 1.0, + 0.00598309, + ], [ - [ - 1.0, - 1.0, - 0.00598309, - ], - [ - 1.0, - 1.0, - 0.91041008, - ], - ] - ), + 1.0, + 1.0, + 0.91041008, + ], + ] ) + true_ci_greenwood = np.array([[1.0, 1.0, -0.19295191], [1.0, 1.0, 1.19295191]]) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci[:, :3], true_ci_logmlog) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci[:, :3], true_ci_greenwood) @staticmethod - def test_truncated_male_older_68(make_channing, channing_male_true_x): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_truncated_male_older_68(make_channing, channing_male_true_x, conf_type): time_enter_m, time_exit_m, event_m = make_channing("Male") - x, y, km_ci = kaplan_meier_estimator(event_m, time_exit_m, time_enter_m, time_min=68 * 12, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event_m, time_exit_m, time_enter_m, time_min=68 * 12, conf_type=conf_type) x_true = channing_male_true_x[6:] @@ -3549,7 +5392,7 @@ def test_truncated_male_older_68(make_channing, channing_male_true_x): ) assert_array_almost_equal(y[18:], y_true) - km_var_true = np.array( + km_ci_true_logmlog = np.array( [ [ 0.73920615, @@ -3786,14 +5629,255 @@ def test_truncated_male_older_68(make_channing, channing_male_true_x): ] ) - assert_array_almost_equal(km_ci[:, :18], np.ones((2, 18))) - assert_array_almost_equal(km_ci[:, 18:], km_var_true) + km_ci_true_greenwood = np.array( + [ + [ + 0.87838763, + 0.87838763, + 0.81365502, + 0.81365502, + 0.75903569, + 0.75903569, + 0.75903569, + 0.75903569, + 0.75903569, + 0.75903569, + 0.75903569, + 0.75903569, + 0.72542183, + 0.69392151, + 0.69392151, + 0.66308011, + 0.66308011, + 0.66308011, + 0.63617431, + 0.60965074, + 0.5844797, + 0.5844797, + 0.5844797, + 0.5844797, + 0.5844797, + 0.5844797, + 0.5844797, + 0.5844797, + 0.56371736, + 0.54312226, + 0.54312226, + 0.54312226, + 0.54312226, + 0.54312226, + 0.52338793, + 0.52338793, + 0.50379951, + 0.50379951, + 0.50379951, + 0.50379951, + 0.50379951, + 0.48567218, + 0.48567218, + 0.48567218, + 0.48567218, + 0.48567218, + 0.46884015, + 0.46884015, + 0.45316342, + 0.45316342, + 0.43805743, + 0.43805743, + 0.43805743, + 0.43805743, + 0.43805743, + 0.43805743, + 0.43805743, + 0.4239445, + 0.4239445, + 0.41032046, + 0.41032046, + 0.39716531, + 0.37100274, + 0.37100274, + 0.35754855, + 0.35754855, + 0.35754855, + 0.35754855, + 0.35754855, + 0.35754855, + 0.34255878, + 0.34255878, + 0.31508576, + 0.31508576, + 0.31508576, + 0.31508576, + 0.31508576, + 0.31508576, + 0.31508576, + 0.30024919, + 0.30024919, + 0.28480293, + 0.28480293, + 0.26947371, + 0.25426555, + 0.23734148, + 0.22060452, + 0.22060452, + 0.22060452, + 0.20688019, + 0.19239852, + 0.19239852, + 0.19239852, + 0.19239852, + 0.19239852, + 0.17702918, + 0.16185167, + 0.16185167, + 0.14555061, + 0.12953106, + 0.12953106, + 0.12953106, + 0.12953106, + 0.12953106, + 0.10980479, + 0.09071042, + 0.09071042, + 0.04839848, + 0.04839848, + 0.04839848, + 0.04839848, + -0.00489106, + -0.03698163, + -0.03698163, + ], + [ + 1.03827904, + 1.03827904, + 1.02634498, + 1.02634498, + 1.00736431, + 1.00736431, + 1.00736431, + 1.00736431, + 1.00736431, + 1.00736431, + 1.00736431, + 1.00736431, + 0.9874509, + 0.96704598, + 0.96704598, + 0.94598215, + 0.94598215, + 0.94598215, + 0.92556258, + 0.90476079, + 0.88404058, + 0.88404058, + 0.88404058, + 0.88404058, + 0.88404058, + 0.88404058, + 0.88404058, + 0.88404058, + 0.86511318, + 0.84601855, + 0.84601855, + 0.84601855, + 0.84601855, + 0.84601855, + 0.82716563, + 0.82716563, + 0.8081668, + 0.8081668, + 0.8081668, + 0.8081668, + 0.8081668, + 0.78985063, + 0.78985063, + 0.78985063, + 0.78985063, + 0.78985063, + 0.77220907, + 0.77220907, + 0.75522661, + 0.75522661, + 0.73853286, + 0.73853286, + 0.73853286, + 0.73853286, + 0.73853286, + 0.73853286, + 0.73853286, + 0.72247681, + 0.72247681, + 0.70670543, + 0.70670543, + 0.69121889, + 0.66009809, + 0.66009809, + 0.64409224, + 0.64409224, + 0.64409224, + 0.64409224, + 0.64409224, + 0.64409224, + 0.62677102, + 0.62677102, + 0.59366093, + 0.59366093, + 0.59366093, + 0.59366093, + 0.59366093, + 0.59366093, + 0.59366093, + 0.57604226, + 0.57604226, + 0.55778501, + 0.55778501, + 0.53941071, + 0.52091535, + 0.50092605, + 0.48074963, + 0.48074963, + 0.48074963, + 0.46259422, + 0.44360217, + 0.44360217, + 0.44360217, + 0.44360217, + 0.44360217, + 0.42363815, + 0.40348228, + 0.40348228, + 0.38209441, + 0.36042503, + 0.36042503, + 0.36042503, + 0.36042503, + 0.36042503, + 0.33560984, + 0.31016274, + 0.31016274, + 0.25225639, + 0.25225639, + 0.25225639, + 0.25225639, + 0.20532764, + 0.13719993, + 0.13719993, + ], + ] + ) + if conf_type == "log-log": + assert_array_almost_equal(km_ci[:, :18], np.ones((2, 18))) + assert_array_almost_equal(km_ci[:, 18:], km_ci_true_logmlog) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci[:, :18], np.ones((2, 18))) + assert_array_almost_equal(km_ci[:, 18:], km_ci_true_greenwood) @staticmethod - def test_truncated_female(make_channing, channing_female_true_x): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_truncated_female(make_channing, channing_female_true_x, conf_type): time_enter_f, time_exit_f, event_f = make_channing("Female") - x, y, km_ci = kaplan_meier_estimator(event_f, time_exit_f, time_enter_f, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event_f, time_exit_f, time_enter_f, conf_type=conf_type) assert_array_equal(x, channing_female_true_x) @@ -4081,7 +6165,7 @@ def test_truncated_female(make_channing, channing_female_true_x): ) assert_array_almost_equal(y[19:], y_true) - km_ci_true = np.array( + km_ci_true_logmlog = np.array( [ [ 1, @@ -4679,13 +6763,617 @@ def test_truncated_female(make_channing, channing_female_true_x): ], ] ) - assert_array_almost_equal(km_ci, km_ci_true) + + km_ci_true_greenwood = np.array( + [ + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.86129862, + 0.82367457, + 0.82367457, + 0.82367457, + 0.82367457, + 0.82367457, + 0.82367457, + 0.82367457, + 0.79842715, + 0.79842715, + 0.79842715, + 0.79842715, + 0.79842715, + 0.79842715, + 0.79842715, + 0.79842715, + 0.79842715, + 0.78038478, + 0.78038478, + 0.78038478, + 0.78038478, + 0.78038478, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.7654239, + 0.75516121, + 0.75516121, + 0.75516121, + 0.75516121, + 0.75516121, + 0.75516121, + 0.74622254, + 0.74622254, + 0.74622254, + 0.74622254, + 0.74622254, + 0.7379748, + 0.7379748, + 0.7379748, + 0.7379748, + 0.7379748, + 0.7379748, + 0.7379748, + 0.7379748, + 0.73062434, + 0.72356096, + 0.72356096, + 0.72356096, + 0.72356096, + 0.72356096, + 0.72356096, + 0.72356096, + 0.72356096, + 0.72356096, + 0.72356096, + 0.71759446, + 0.71759446, + 0.7118213, + 0.7118213, + 0.7118213, + 0.7118213, + 0.70631623, + 0.70631623, + 0.70631623, + 0.69572266, + 0.69572266, + 0.69572266, + 0.68521825, + 0.68521825, + 0.68521825, + 0.68016149, + 0.68016149, + 0.68016149, + 0.68016149, + 0.67510691, + 0.67510691, + 0.67510691, + 0.67510691, + 0.67012834, + 0.67012834, + 0.67012834, + 0.67012834, + 0.66546096, + 0.66546096, + 0.66546096, + 0.66089057, + 0.66089057, + 0.65625824, + 0.65625824, + 0.65165974, + 0.64709479, + 0.64256313, + 0.64256313, + 0.63800115, + 0.63800115, + 0.63374269, + 0.63374269, + 0.63374269, + 0.62945694, + 0.62514337, + 0.62514337, + 0.62514337, + 0.61691444, + 0.61691444, + 0.61691444, + 0.61691444, + 0.61268993, + 0.61268993, + 0.61268993, + 0.61268993, + 0.61268993, + 0.60846672, + 0.60846672, + 0.60846672, + 0.60846672, + 0.60846672, + 0.60433024, + 0.60433024, + 0.60433024, + 0.60433024, + 0.60010901, + 0.60010901, + 0.60010901, + 0.59585948, + 0.59585948, + 0.59585948, + 0.58781933, + 0.58380097, + 0.58380097, + 0.58380097, + 0.58380097, + 0.57975614, + 0.57565576, + 0.57565576, + 0.57152749, + 0.57152749, + 0.57152749, + 0.56333356, + 0.55923854, + 0.55923854, + 0.55923854, + 0.55502243, + 0.55502243, + 0.55502243, + 0.55077584, + 0.53354211, + 0.52923785, + 0.52493529, + 0.52493529, + 0.51626503, + 0.507602, + 0.49451092, + 0.49451092, + 0.49011274, + 0.48567725, + 0.48124389, + 0.47677197, + 0.47677197, + 0.47226048, + 0.46775134, + 0.4632871, + 0.45427993, + 0.45427993, + 0.45427993, + 0.45427993, + 0.44964394, + 0.44491449, + 0.43526251, + 0.43027822, + 0.42540884, + 0.42037272, + 0.42037272, + 0.41534066, + 0.40528888, + 0.3952539, + 0.38997985, + 0.38456739, + 0.38456739, + 0.37360904, + 0.36797759, + 0.36797759, + 0.36243616, + 0.36243616, + 0.35706236, + 0.35161447, + 0.35161447, + 0.35161447, + 0.34582223, + 0.34582223, + 0.34582223, + 0.34582223, + 0.34582223, + 0.34582223, + 0.34582223, + 0.32155524, + 0.30346442, + 0.30346442, + 0.29745582, + 0.29133837, + 0.28510615, + 0.28510615, + 0.28510615, + 0.28510615, + 0.28510615, + 0.28510615, + 0.28510615, + 0.27053674, + 0.27053674, + 0.27053674, + 0.27053674, + 0.27053674, + 0.26245652, + 0.25416792, + 0.25416792, + 0.23718352, + 0.22875579, + 0.22875579, + 0.22037187, + 0.2117232, + 0.20312553, + 0.20312553, + 0.19422819, + 0.18539122, + 0.16790377, + 0.15882236, + 0.15882236, + 0.14932295, + 0.14932295, + 0.13934521, + 0.13934521, + 0.13008091, + 0.13008091, + 0.13008091, + 0.13008091, + 0.13008091, + 0.11674875, + 0.11674875, + 0.10237793, + 0.08844233, + 0.07495408, + 0.07495408, + 0.07495408, + 0.06193456, + 0.06193456, + 0.04714269, + 0.03321498, + 0.03321498, + 0.0094956, + -0.02025011, + -0.02025011, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.04346329, + 1.02817728, + 1.02817728, + 1.02817728, + 1.02817728, + 1.02817728, + 1.02817728, + 1.02817728, + 1.01316705, + 1.01316705, + 1.01316705, + 1.01316705, + 1.01316705, + 1.01316705, + 1.01316705, + 1.01316705, + 1.01316705, + 0.99997504, + 0.99997504, + 0.99997504, + 0.99997504, + 0.99997504, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.98796078, + 0.97852251, + 0.97852251, + 0.97852251, + 0.97852251, + 0.97852251, + 0.97852251, + 0.96994923, + 0.96994923, + 0.96994923, + 0.96994923, + 0.96994923, + 0.96185247, + 0.96185247, + 0.96185247, + 0.96185247, + 0.96185247, + 0.96185247, + 0.96185247, + 0.96185247, + 0.95442183, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94720515, + 0.94088661, + 0.94088661, + 0.93472824, + 0.93472824, + 0.93472824, + 0.93472824, + 0.92879895, + 0.92879895, + 0.92879895, + 0.91729636, + 0.91729636, + 0.91729636, + 0.90585494, + 0.90585494, + 0.90585494, + 0.90030455, + 0.90030455, + 0.90030455, + 0.90030455, + 0.89475197, + 0.89475197, + 0.89475197, + 0.89475197, + 0.88926482, + 0.88926482, + 0.88926482, + 0.88926482, + 0.88406262, + 0.88406262, + 0.88406262, + 0.87894849, + 0.87894849, + 0.87377292, + 0.87377292, + 0.868626, + 0.8635076, + 0.85841759, + 0.85841759, + 0.85329582, + 0.85329582, + 0.84846101, + 0.84846101, + 0.84846101, + 0.84359736, + 0.83870433, + 0.83870433, + 0.83870433, + 0.82929655, + 0.82929655, + 0.82929655, + 0.82929655, + 0.82448224, + 0.82448224, + 0.82448224, + 0.82448224, + 0.82448224, + 0.81966663, + 0.81966663, + 0.81966663, + 0.81966663, + 0.81966663, + 0.81493272, + 0.81493272, + 0.81493272, + 0.81493272, + 0.81011405, + 0.81011405, + 0.81011405, + 0.80526537, + 0.80526537, + 0.80526537, + 0.79600768, + 0.79137713, + 0.79137713, + 0.79137713, + 0.79137713, + 0.7867183, + 0.78200271, + 0.78200271, + 0.7772574, + 0.7772574, + 0.7772574, + 0.76782015, + 0.76309958, + 0.76309958, + 0.76309958, + 0.75825858, + 0.75825858, + 0.75825858, + 0.75338516, + 0.73362142, + 0.72867631, + 0.7237295, + 0.7237295, + 0.71376297, + 0.70378922, + 0.68870841, + 0.68870841, + 0.68364083, + 0.67853361, + 0.67342426, + 0.66827395, + 0.66827395, + 0.66308166, + 0.65788702, + 0.65273042, + 0.64232858, + 0.64232858, + 0.64232858, + 0.64232858, + 0.63699541, + 0.63156935, + 0.62051972, + 0.61483954, + 0.60925775, + 0.60351608, + 0.60351608, + 0.59777037, + 0.5862666, + 0.57474602, + 0.56874101, + 0.56260261, + 0.56260261, + 0.55017405, + 0.54380831, + 0.54380831, + 0.53750837, + 0.53750837, + 0.53134441, + 0.52510274, + 0.52510274, + 0.52510274, + 0.51854686, + 0.51854686, + 0.51854686, + 0.51854686, + 0.51854686, + 0.51854686, + 0.51854686, + 0.49120972, + 0.47059744, + 0.47059744, + 0.46370502, + 0.456699, + 0.44957341, + 0.44957341, + 0.44957341, + 0.44957341, + 0.44957341, + 0.44957341, + 0.44957341, + 0.43353117, + 0.43353117, + 0.43353117, + 0.43353117, + 0.43353117, + 0.42484787, + 0.41595386, + 0.41595386, + 0.39766869, + 0.38846164, + 0.38846164, + 0.37921078, + 0.36969028, + 0.36011877, + 0.36011877, + 0.3502413, + 0.34030346, + 0.32024129, + 0.3097969, + 0.3097969, + 0.29892156, + 0.29892156, + 0.28755433, + 0.28755433, + 0.27649008, + 0.27649008, + 0.27649008, + 0.27649008, + 0.27649008, + 0.2627175, + 0.2627175, + 0.24789861, + 0.2326445, + 0.21694304, + 0.21694304, + 0.21694304, + 0.20077284, + 0.20077284, + 0.18272629, + 0.16381557, + 0.16381557, + 0.13827732, + 0.06950775, + 0.06950775, + ], + ] + ) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci, km_ci_true_logmlog) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci, km_ci_true_greenwood) @staticmethod - def test_truncated_female_older_68(make_channing, channing_female_true_x): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_truncated_female_older_68(make_channing, channing_female_true_x, conf_type): time_enter_f, time_exit_f, event_f = make_channing("Female") - x, y, km_ci = kaplan_meier_estimator(event_f, time_exit_f, time_enter_f, time_min=68 * 12, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event_f, time_exit_f, time_enter_f, time_min=68 * 12, conf_type=conf_type) x_true = channing_female_true_x[30:] assert_array_equal(x, x_true) @@ -4961,7 +7649,7 @@ def test_truncated_female_older_68(make_channing, channing_female_true_x): ) assert_array_almost_equal(y, y_true) - km_var_true = np.array( + km_ci_true_logmlog = np.array( [ [ 1.0, @@ -5499,13 +8187,557 @@ def test_truncated_female_older_68(make_channing, channing_female_true_x): ], ] ) - assert_array_almost_equal(km_ci, km_var_true) + + km_ci_true_greenwood = np.array( + [ + [ + 1.0, + 1.0, + 1.0, + 1.0, + 0.91854026, + 0.91854026, + 0.91854026, + 0.91854026, + 0.91854026, + 0.91854026, + 0.91854026, + 0.88447994, + 0.88447994, + 0.88447994, + 0.88447994, + 0.88447994, + 0.88447994, + 0.88447994, + 0.88447994, + 0.88447994, + 0.86188798, + 0.86188798, + 0.86188798, + 0.86188798, + 0.86188798, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.84371967, + 0.83161, + 0.83161, + 0.83161, + 0.83161, + 0.83161, + 0.83161, + 0.82117042, + 0.82117042, + 0.82117042, + 0.82117042, + 0.82117042, + 0.81159736, + 0.81159736, + 0.81159736, + 0.81159736, + 0.81159736, + 0.81159736, + 0.81159736, + 0.81159736, + 0.80312128, + 0.79500168, + 0.79500168, + 0.79500168, + 0.79500168, + 0.79500168, + 0.79500168, + 0.79500168, + 0.79500168, + 0.79500168, + 0.79500168, + 0.78818923, + 0.78818923, + 0.78161046, + 0.78161046, + 0.78161046, + 0.78161046, + 0.77535114, + 0.77535114, + 0.77535114, + 0.76333242, + 0.76333242, + 0.76333242, + 0.75143412, + 0.75143412, + 0.75143412, + 0.74571744, + 0.74571744, + 0.74571744, + 0.74571744, + 0.74000675, + 0.74000675, + 0.74000675, + 0.74000675, + 0.73438726, + 0.73438726, + 0.73438726, + 0.73438726, + 0.7291297, + 0.7291297, + 0.7291297, + 0.72398618, + 0.72398618, + 0.71877378, + 0.71877378, + 0.71360263, + 0.70847229, + 0.70338236, + 0.70338236, + 0.69825985, + 0.69825985, + 0.69348674, + 0.69348674, + 0.69348674, + 0.68868417, + 0.68385158, + 0.68385158, + 0.68385158, + 0.67464554, + 0.67464554, + 0.67464554, + 0.67464554, + 0.66991938, + 0.66991938, + 0.66991938, + 0.66991938, + 0.66991938, + 0.66519632, + 0.66519632, + 0.66519632, + 0.66519632, + 0.66519632, + 0.66057357, + 0.66057357, + 0.66057357, + 0.66057357, + 0.6558559, + 0.6558559, + 0.6558559, + 0.65110762, + 0.65110762, + 0.65110762, + 0.64213697, + 0.63765556, + 0.63765556, + 0.63765556, + 0.63765556, + 0.63314545, + 0.6285737, + 0.6285737, + 0.62397174, + 0.62397174, + 0.62397174, + 0.61484299, + 0.6102829, + 0.6102829, + 0.6102829, + 0.60558714, + 0.60558714, + 0.60558714, + 0.6008584, + 0.58167954, + 0.57689359, + 0.57211115, + 0.57211115, + 0.56247759, + 0.55285866, + 0.53833369, + 0.53833369, + 0.5334565, + 0.52853902, + 0.52362566, + 0.51867066, + 0.51867066, + 0.51367296, + 0.5086797, + 0.50373863, + 0.49377342, + 0.49377342, + 0.49377342, + 0.49377342, + 0.48864504, + 0.48341384, + 0.47274142, + 0.46723124, + 0.46185213, + 0.45628873, + 0.45628873, + 0.45073238, + 0.43964094, + 0.42857804, + 0.42276405, + 0.41679854, + 0.41679854, + 0.40472815, + 0.39852814, + 0.39852814, + 0.39243176, + 0.39243176, + 0.38652495, + 0.38053872, + 0.38053872, + 0.38053872, + 0.37417312, + 0.37417312, + 0.37417312, + 0.37417312, + 0.37417312, + 0.37417312, + 0.37417312, + 0.3475315, + 0.32771438, + 0.32771438, + 0.32114042, + 0.3144502, + 0.30763747, + 0.30763747, + 0.30763747, + 0.30763747, + 0.30763747, + 0.30763747, + 0.30763747, + 0.29170676, + 0.29170676, + 0.29170676, + 0.29170676, + 0.29170676, + 0.28287569, + 0.27382464, + 0.27382464, + 0.255305, + 0.24613086, + 0.24613086, + 0.23701405, + 0.22761703, + 0.21828533, + 0.21828533, + 0.20863634, + 0.19906302, + 0.18014703, + 0.17033631, + 0.17033631, + 0.16008278, + 0.16008278, + 0.14932295, + 0.14932295, + 0.13934521, + 0.13934521, + 0.13934521, + 0.13934521, + 0.13934521, + 0.12499241, + 0.12499241, + 0.10954867, + 0.09460088, + 0.08015671, + 0.08015671, + 0.08015671, + 0.0662347, + 0.0662347, + 0.05044174, + 0.03559316, + 0.03559316, + 0.01037872, + -0.02119767, + -0.02119767, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.02590418, + 1.02590418, + 1.02590418, + 1.02590418, + 1.02590418, + 1.02590418, + 1.02590418, + 1.01769397, + 1.01769397, + 1.01769397, + 1.01769397, + 1.01769397, + 1.01769397, + 1.01769397, + 1.01769397, + 1.01769397, + 1.00748983, + 1.00748983, + 1.00748983, + 1.00748983, + 1.00748983, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.99733424, + 0.98875791, + 0.98875791, + 0.98875791, + 0.98875791, + 0.98875791, + 0.98875791, + 0.98080993, + 0.98080993, + 0.98080993, + 0.98080993, + 0.98080993, + 0.97322128, + 0.97322128, + 0.97322128, + 0.97322128, + 0.97322128, + 0.97322128, + 0.97322128, + 0.97322128, + 0.96617719, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95930273, + 0.95321589, + 0.95321589, + 0.94726657, + 0.94726657, + 0.94726657, + 0.94726657, + 0.94151979, + 0.94151979, + 0.94151979, + 0.93033756, + 0.93033756, + 0.93033756, + 0.91919273, + 0.91919273, + 0.91919273, + 0.91377189, + 0.91377189, + 0.91377189, + 0.91377189, + 0.90834507, + 0.90834507, + 0.90834507, + 0.90834507, + 0.90297555, + 0.90297555, + 0.90297555, + 0.90297555, + 0.89787006, + 0.89787006, + 0.89787006, + 0.89284483, + 0.89284483, + 0.88775894, + 0.88775894, + 0.8826974, + 0.87766022, + 0.8726474, + 0.8726474, + 0.86760197, + 0.86760197, + 0.86282714, + 0.86282714, + 0.86282714, + 0.85802284, + 0.85318851, + 0.85318851, + 0.85318851, + 0.843876, + 0.843876, + 0.843876, + 0.843876, + 0.83911139, + 0.83911139, + 0.83911139, + 0.83911139, + 0.83911139, + 0.83434369, + 0.83434369, + 0.83434369, + 0.83434369, + 0.83434369, + 0.82965253, + 0.82965253, + 0.82965253, + 0.82965253, + 0.82487832, + 0.82487832, + 0.82487832, + 0.82007347, + 0.82007347, + 0.82007347, + 0.81088139, + 0.80628144, + 0.80628144, + 0.80628144, + 0.80628144, + 0.80165271, + 0.79696769, + 0.79696769, + 0.79225239, + 0.79225239, + 0.79225239, + 0.7828684, + 0.77817213, + 0.77817213, + 0.77817213, + 0.77335792, + 0.77335792, + 0.77335792, + 0.76851065, + 0.74884217, + 0.74391628, + 0.73898688, + 0.73898688, + 0.72905181, + 0.71910212, + 0.70404661, + 0.70404661, + 0.69898475, + 0.69388238, + 0.6887759, + 0.68362755, + 0.68362755, + 0.67843629, + 0.67324057, + 0.66807977, + 0.65766552, + 0.65766552, + 0.65766552, + 0.65766552, + 0.65232628, + 0.64689419, + 0.63582991, + 0.63014241, + 0.62454778, + 0.61879451, + 0.61879451, + 0.6130342, + 0.60149231, + 0.58992188, + 0.58389285, + 0.57772995, + 0.57772995, + 0.56524409, + 0.55884705, + 0.55884705, + 0.55250999, + 0.55250999, + 0.54630217, + 0.54001436, + 0.54001436, + 0.54001436, + 0.53341442, + 0.53341442, + 0.53341442, + 0.53341442, + 0.53341442, + 0.53341442, + 0.53341442, + 0.5058717, + 0.48505058, + 0.48505058, + 0.47807845, + 0.47098904, + 0.46377607, + 0.46377607, + 0.46377607, + 0.46377607, + 0.46377607, + 0.46377607, + 0.46377607, + 0.44756455, + 0.44756455, + 0.44756455, + 0.44756455, + 0.44756455, + 0.43879393, + 0.42980323, + 0.42980323, + 0.41128983, + 0.40194744, + 0.40194744, + 0.39254773, + 0.38286712, + 0.37312119, + 0.37312119, + 0.36305663, + 0.3529164, + 0.33240529, + 0.32171391, + 0.32171391, + 0.31057395, + 0.31057395, + 0.29892156, + 0.29892156, + 0.28755433, + 0.28755433, + 0.28755433, + 0.28755433, + 0.28755433, + 0.27344716, + 0.27344716, + 0.2582417, + 0.24254029, + 0.22633527, + 0.22633527, + 0.22633527, + 0.20960808, + 0.20960808, + 0.19092069, + 0.17128893, + 0.17128893, + 0.14478284, + 0.07291819, + 0.07291819, + ], + ] + ) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci, km_ci_true_logmlog) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci, km_ci_true_greenwood) @staticmethod - def test_right_truncated_children(make_aids): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_right_truncated_children(make_aids, conf_type): event, time_enter, time_exit = make_aids("children") - x, y, km_ci = kaplan_meier_estimator(event, time_exit.values, time_enter.values, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event, time_exit.values, time_enter.values, conf_type=conf_type) true_x = np.array( [ 7.75, @@ -5566,7 +8798,7 @@ def test_right_truncated_children(make_aids): ) assert_array_almost_equal(y[::-1], true_y, 4) - true_km_var = np.array( + true_km_ci_logmlog = np.array( [ [ 0.0, @@ -5623,13 +8855,74 @@ def test_right_truncated_children(make_aids): ] ) - assert_array_almost_equal(km_ci[:, ::-1], true_km_var, 6) + true_km_ci_greenwood = np.array( + [ + [ + 0.0, + -0.01454119, + -0.01282925, + -0.00273296, + 0.00855181, + 0.01444552, + 0.02446431, + 0.03675121, + 0.03675121, + 0.04756275, + 0.05407834, + 0.06290989, + 0.09202081, + 0.11113964, + 0.13323204, + 0.13323204, + 0.13323204, + 0.13323204, + 0.13323204, + 0.13323204, + 0.13323204, + 1.0, + 1.0, + 1.0, + ], + [ + 0.0, + 0.05918221, + 0.16907282, + 0.31522011, + 0.44597495, + 0.50501363, + 0.59888667, + 0.70347808, + 0.70347808, + 0.7851952, + 0.83072697, + 0.88509581, + 1.02834956, + 1.11108258, + 1.2001013, + 1.2001013, + 1.2001013, + 1.2001013, + 1.2001013, + 1.2001013, + 1.2001013, + 1.0, + 1.0, + 1.0, + ], + ] + ) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci[:, ::-1], true_km_ci_logmlog, 6) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci[:, ::-1], true_km_ci_greenwood, 6) @staticmethod - def test_right_truncated_adults(make_aids): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_right_truncated_adults(make_aids, conf_type): event, time_enter, time_exit = make_aids("adults") - x, y, km_ci = kaplan_meier_estimator(event, time_exit.values, time_enter.values, conf_type="log-log") + x, y, km_ci = kaplan_meier_estimator(event, time_exit.values, time_enter.values, conf_type=conf_type) true_x = np.array( [ @@ -5705,7 +8998,7 @@ def test_right_truncated_adults(make_aids): ) assert_array_almost_equal(y[::-1], true_y, 6) - true_km_var = np.array( + true_km_ci_logmlog = np.array( [ [ 0, @@ -5776,7 +9069,81 @@ def test_right_truncated_adults(make_aids): ] ) - assert_array_almost_equal(km_ci[:, ::-1], true_km_var, 6) + true_km_ci_greenwood = np.array( + [ + [ + 0.00000000e00, + 5.37166124e-06, + 4.11938514e-04, + 4.24402072e-03, + 9.26836800e-03, + 1.44108115e-02, + 2.45505631e-02, + 2.99425118e-02, + 3.81995945e-02, + 5.21355074e-02, + 6.39045942e-02, + 7.79719385e-02, + 1.00603683e-01, + 1.10810783e-01, + 1.26611163e-01, + 1.38981694e-01, + 1.68434270e-01, + 1.92908051e-01, + 2.11161044e-01, + 2.41000297e-01, + 3.00204719e-01, + 3.49013021e-01, + 4.31696176e-01, + 4.77879159e-01, + 4.77879159e-01, + 5.39521039e-01, + 6.45827696e-01, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + ], + [ + 0.00000000e00, + 7.25274760e-03, + 9.26555384e-03, + 2.34059574e-02, + 4.02105403e-02, + 5.65089570e-02, + 8.74280188e-02, + 1.03645972e-01, + 1.27549080e-01, + 1.66989858e-01, + 1.99680991e-01, + 2.37532019e-01, + 2.97205655e-01, + 3.23638888e-01, + 3.63006722e-01, + 3.92844628e-01, + 4.61878408e-01, + 5.17626241e-01, + 5.57376047e-01, + 6.19029306e-01, + 7.31830804e-01, + 8.16188376e-01, + 9.36148941e-01, + 9.99393568e-01, + 9.99393568e-01, + 1.05138805e00, + 1.10417230e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + 1.00000000e00, + ], + ] + ) + + if conf_type == "log-log": + assert_array_almost_equal(km_ci[:, ::-1], true_km_ci_logmlog, 6) + elif conf_type == "greenwood": + assert_array_almost_equal(km_ci[:, ::-1], true_km_ci_greenwood, 6) @staticmethod def test_censoring_distribution(): @@ -5793,14 +9160,15 @@ def test_censoring_distribution(): assert_array_almost_equal(expected, probs) @staticmethod - def test_reverse_conf_int(random_survival_data): + @pytest.mark.parametrize("conf_type", ["log-log", "greenwood"]) + def test_reverse_conf_int(random_survival_data, conf_type): msg = "Confidence intervals of the censoring distribution is not implemented" with pytest.raises(NotImplementedError, match=msg): kaplan_meier_estimator( random_survival_data["event"], random_survival_data["time"], reverse=True, - conf_type="log-log", + conf_type=conf_type, ) @staticmethod