Skip to content

Commit

Permalink
Refactored calibrate_double_ended covariance calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
bdestombe committed Aug 18, 2023
1 parent b453f59 commit e884f94
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 157 deletions.
127 changes: 4 additions & 123 deletions src/dtscalibration/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2260,8 +2260,9 @@ def calibration_double_ended(
assert p_cov.shape == (ip.npar, ip.npar)

coords = {"x": self["x"], "time": self["time"], "trans_att": self["trans_att"]}
params = get_params_from_pval(ip, p_val, coords)
param_covs = get_params_from_pval(ip, p_var, coords)
params = get_params_from_pval(ip, coords, p_val=p_val)
param_covs = get_params_from_pval(ip, coords, p_val=p_var, p_cov=p_cov)

out = xr.Dataset(
{
"tmpf": params["gamma"]
Expand All @@ -2283,126 +2284,6 @@ def calibration_double_ended(
}
)

# extract covariances and ensure broadcastable to (nx, nt)
param_covs["gamma_df"] = (("time",), p_cov[np.ix_(ip.gamma, ip.df)][0])
param_covs["gamma_db"] = (("time",), p_cov[np.ix_(ip.gamma, ip.db)][0])
param_covs["gamma_alpha"] = (("x",), p_cov[np.ix_(ip.alpha, ip.gamma)][:, 0])
param_covs["df_db"] = (
("time",),
p_cov[ip.df, ip.db],
)
param_covs["alpha_df"] = (
(
"x",
"time",
),
p_cov[np.ix_(ip.alpha, ip.df)],
)
param_covs["alpha_db"] = (
(
"x",
"time",
),
p_cov[np.ix_(ip.alpha, ip.db)],
)
param_covs["tafw_gamma"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.gamma],
x=self.x.values,
trans_att=self.trans_att.values,
axis="",
),
)
param_covs["tabw_gamma"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.gamma],
x=self.x.values,
trans_att=self.trans_att.values,
axis="",
),
)
param_covs["tafw_alpha"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.alpha],
x=self.x.values,
trans_att=self.trans_att.values,
axis="x",
),
)
param_covs["tabw_alpha"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.alpha],
x=self.x.values,
trans_att=self.trans_att.values,
axis="x",
),
)
param_covs["tafw_df"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.df],
x=self.x.values,
trans_att=self.trans_att.values,
axis="time",
),
)
param_covs["tafw_db"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.db],
x=self.x.values,
trans_att=self.trans_att.values,
axis="time",
),
)
param_covs["tabw_db"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.db],
x=self.x.values,
trans_att=self.trans_att.values,
axis="time",
),
)
param_covs["tabw_df"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.df],
x=self.x.values,
trans_att=self.trans_att.values,
axis="time",
),
)
# sigma2_tafw_tabw

tmpf = out["tmpf"] + 273.15
tmpb = out["tmpb"] + 273.15

Expand Down Expand Up @@ -2597,7 +2478,7 @@ def calibration_double_ended(

out.update(params)

for key, da in param_covs.items():
for key, da in param_covs.data_vars.items():
out[key + "_var"] = da

self.update(out)
Expand Down
197 changes: 163 additions & 34 deletions src/dtscalibration/datastore_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,42 +528,171 @@ def get_netcdf_encoding(
return encoding


def get_params_from_pval(ip, p_val, coords):
assert len(p_val) == ip.npar, "Length of p_val is incorrect"

params = xr.Dataset(coords=coords)

# save estimates and variances to datastore, skip covariances
params["gamma"] = (tuple(), p_val[ip.gamma].item())
params["alpha"] = (("x",), p_val[ip.alpha])
params["df"] = (("time",), p_val[ip.df])
params["db"] = (("time",), p_val[ip.db])

if ip.nta:
params["talpha_fw"] = (
("time", "trans_att"),
p_val[ip.taf].reshape((ip.nt, ip.nta), order="C"),
def get_params_from_pval(ip, coords, p_val=None, p_cov=None):
if p_val is not None:
assert len(p_val) == ip.npar, "Length of p_val is incorrect"

params = xr.Dataset(coords=coords)

# save estimates and variances to datastore, skip covariances
params["gamma"] = (tuple(), p_val[ip.gamma].item())
params["alpha"] = (("x",), p_val[ip.alpha])
params["df"] = (("time",), p_val[ip.df])
params["db"] = (("time",), p_val[ip.db])

if ip.nta:
params["talpha_fw"] = (
("time", "trans_att"),
p_val[ip.taf].reshape((ip.nt, ip.nta), order="C"),
)
params["talpha_bw"] = (
("time", "trans_att"),
p_val[ip.tab].reshape((ip.nt, ip.nta), order="C"),
)
else:
params["talpha_fw"] = (("time", "trans_att"), np.zeros((ip.nt, 0)))
params["talpha_bw"] = (("time", "trans_att"), np.zeros((ip.nt, 0)))

params["talpha_fw_full"] = (
("x", "time"),
ip.get_taf_values(
pval=p_val,
x=params.x.values,
trans_att=params.trans_att.values,
axis="",
),
)
params["talpha_bw"] = (
("time", "trans_att"),
p_val[ip.tab].reshape((ip.nt, ip.nta), order="C"),
params["talpha_bw_full"] = (
("x", "time"),
ip.get_tab_values(
pval=p_val,
x=params.x.values,
trans_att=params.trans_att.values,
axis="",
),
)
else:
params["talpha_fw"] = (("time", "trans_att"), np.zeros((ip.nt, 0)))
params["talpha_bw"] = (("time", "trans_att"), np.zeros((ip.nt, 0)))

params["talpha_fw_full"] = (
("x", "time"),
ip.get_taf_values(
pval=p_val, x=params.x.values, trans_att=params.trans_att.values, axis=""
),
)
params["talpha_bw_full"] = (
("x", "time"),
ip.get_tab_values(
pval=p_val, x=params.x.values, trans_att=params.trans_att.values, axis=""
),
)
if p_cov is not None:
assert p_cov.shape == (ip.npar, ip.npar), "Shape of p_cov is incorrect"

# extract covariances and ensure broadcastable to (nx, nt)
params["gamma_df"] = (("time",), p_cov[np.ix_(ip.gamma, ip.df)][0])
params["gamma_db"] = (("time",), p_cov[np.ix_(ip.gamma, ip.db)][0])
params["gamma_alpha"] = (("x",), p_cov[np.ix_(ip.alpha, ip.gamma)][:, 0])
params["df_db"] = (
("time",),
p_cov[ip.df, ip.db],
)
params["alpha_df"] = (
(
"x",
"time",
),
p_cov[np.ix_(ip.alpha, ip.df)],
)
params["alpha_db"] = (
(
"x",
"time",
),
p_cov[np.ix_(ip.alpha, ip.db)],
)
params["tafw_gamma"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.gamma],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="",
),
)
params["tabw_gamma"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.gamma],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="",
),
)
params["tafw_alpha"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.alpha],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="x",
),
)
params["tabw_alpha"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.alpha],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="x",
),
)
params["tafw_df"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.df],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="time",
),
)
params["tafw_db"] = (
(
"x",
"time",
),
ip.get_taf_values(
pval=p_cov[ip.db],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="time",
),
)
params["tabw_db"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.db],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="time",
),
)
params["tabw_df"] = (
(
"x",
"time",
),
ip.get_tab_values(
pval=p_cov[ip.df],
x=params["x"].values,
trans_att=params["trans_att"].values,
axis="time",
),
)
# sigma2_tafw_tabw
return params


Expand Down

0 comments on commit e884f94

Please sign in to comment.