diff --git a/src/dtscalibration/datastore_utils.py b/src/dtscalibration/datastore_utils.py index 39aff46f..1e1d3a79 100644 --- a/src/dtscalibration/datastore_utils.py +++ b/src/dtscalibration/datastore_utils.py @@ -93,7 +93,7 @@ def check_timestep_allclose(ds, eps=0.01): 'for all time steps' -def merge_double_ended(ds_fw, ds_bw, cable_length, plot_result=True): +def merge_double_ended(ds_fw, ds_bw, cable_length, plot_result=True, verbose=True): """ Some measurements are not set up on the DTS-device as double-ended meausurements. This means that the two channels have to be merged manually. @@ -123,9 +123,7 @@ def merge_double_ended(ds_fw, ds_bw, cable_length, plot_result=True): and ds_bw.attrs['isDoubleEnded'] == '0'), \ "(one of the) input DataStores is already double ended" - assert (ds_fw.time.size == ds_bw.time.size), \ - "The two input DataStore objects are not of the same size in the " +\ - "time dimension." + ds_fw, ds_bw = merge_double_ended_times(ds_fw, ds_bw, verbose=verbose) ds = ds_fw.copy() ds_bw = ds_bw.copy() @@ -158,6 +156,85 @@ def merge_double_ended(ds_fw, ds_bw, cable_length, plot_result=True): return ds +def merge_double_ended_times(ds_fw, ds_bw, verbose=True): + """Helper for `merge_double_ended()` to deal with missing measurements. The + number of measurements of the forward and backward channels might get out + of sync if the device shuts down before the measurement of the last channel + is complete. This skips all measurements that are not accompanied by a partner + channel. + + Provides little protection against swapping fw and bw. + + If all measurements are recorded: fw_t0, bw_t0, fw_t1, bw_t1, fw_t2, bw_t2, .. + > all are passed + + If some are missing the accompanying measurement is skipped: + - fw_t0, bw_t0, bw_t1, fw_t2, bw_t2, .. > fw_t0, bw_t0, fw_t2, bw_t2, .. + - fw_t0, bw_t0, fw_t1, fw_t2, bw_t2, .. > fw_t0, bw_t0, fw_t2, bw_t2, .. + - fw_t0, bw_t0, bw_t1, fw_t2, fw_t3, bw_t3, .. 
> fw_t0, bw_t0, fw_t3, bw_t3, + + Mixing forward and backward channels can be problematic when there is a pause + after measuring all channels. This function is not perfect as the following + situation is not caught: + - fw_t0, bw_t0, fw_t1, bw_t2, fw_t3, bw_t3, .. + > fw_t0, bw_t0, fw_t1, bw_t2, fw_t3, bw_t3, .. + + This routine checks that the lowest channel + number is measured first (aka the forward channel), but it doesn't catch the + last case as it doesn't know that fw_t1 and bw_t2 belong to different cycles. + + Parameters + ---------- + ds_fw : DataStore object + DataStore object representing the forward measurement channel + ds_bw : DataStore object + DataStore object representing the backward measurement channel + + Returns + ------- + ds_fw_sel : DataStore object + DataStore object representing the forward measurement channel with + only times for which there is also a ds_bw measurement + ds_bw_sel : DataStore object + DataStore object representing the backward measurement channel with + only times for which there is also a ds_fw measurement + """ + if 'forward channel' in ds_fw.attrs and 'forward channel' in ds_bw.attrs: + assert ds_fw.attrs['forward channel'] < ds_bw.attrs['forward channel'], "ds_fw and ds_bw are swapped" + elif 'forwardMeasurementChannel' in ds_fw.attrs and 'forwardMeasurementChannel' in ds_bw.attrs: + assert ds_fw.attrs['forwardMeasurementChannel'] < ds_bw.attrs['forwardMeasurementChannel'], \ + "ds_fw and ds_bw are swapped" + + if (ds_bw.time.size == ds_fw.time.size) and np.all(ds_bw.time.values > ds_fw.time.values): + return ds_fw, ds_bw + + iuse_chfw = list() + iuse_chbw = list() + + times_fw = {k: ("fw", i) for i, k in enumerate(ds_fw.time.values)} + times_bw = {k: ("bw", i) for i, k in enumerate(ds_bw.time.values)} + times_all = dict(sorted(({**times_fw, **times_bw}).items())) + times_all_val = list(times_all.values()) + + for (direction, ind), (direction_next, ind_next) in zip(times_all_val[:-1], times_all_val[1:]): + if direction 
== "fw" and direction_next == "bw": + iuse_chfw.append(ind) + iuse_chbw.append(ind_next) + + elif direction == "bw" and direction_next == "fw": + pass + + elif direction == "fw" and direction_next == "fw": + if verbose: + print(f"Missing backward measurement beween {ds_fw.time.values[ind]} and {ds_fw.time.values[ind_next]}") + + elif direction == "bw" and direction_next == "bw": + if verbose: + print(f"Missing forward measurement beween {ds_bw.time.values[ind]} and {ds_bw.time.values[ind_next]}") + + return ds_fw.isel(time=iuse_chfw), ds_bw.isel(time=iuse_chbw) + + # pylint: disable=too-many-locals def shift_double_ended(ds, i_shift, verbose=True): """ diff --git a/tests/test_datastore.py b/tests/test_datastore.py index d27e01fa..8f1e3534 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -637,3 +637,65 @@ def test_merge_double_ended(): result = (ds.isel(time=0).st - ds.isel(time=0).rst).sum().values np.testing.assert_approx_equal(result, -3712866.0382, significant=10) + + +@pytest.mark.parametrize( + "inotinfw, inotinbw, inotinout", + [ + ([], [], []), + ([1], [], [1]), + ([], [1], [1]), + ([1], [1], [1]), + ([1, 2], [], [1, 2]), + ([], [1, 2], [1, 2]), + ([1], [2], [1, 2]), + pytest.param([2], [1], [1, 2], marks=pytest.mark.xfail) + ]) +def test_merge_double_ended_times(inotinfw, inotinbw, inotinout): + """ + Arguments are the indices not included in resp fw measurements, bw measurements, + and merged output. + + If all measurements are recorded: fw_t0, bw_t0, fw_t1, bw_t1, fw_t2, bw_t2, .. + > all are passed + + If some are missing the accompanying measurement is skipped: + - fw_t0, bw_t0, bw_t1, fw_t2, bw_t2, .. > fw_t0, bw_t0, fw_t2, bw_t2, .. + - fw_t0, bw_t0, fw_t1, fw_t2, bw_t2, .. > fw_t0, bw_t0, fw_t2, bw_t2, .. + - fw_t0, bw_t0, bw_t1, fw_t2, fw_t3, bw_t3, .. > fw_t0, bw_t0, fw_t3, bw_t3, + + Mixing forward and backward channels can be problematic when there is a pause + after measuring all channels. 
This function is not perfect as the following + situation is not caught: + - fw_t0, bw_t0, fw_t1, bw_t2, fw_t3, bw_t3, .. + > fw_t0, bw_t0, fw_t1, bw_t2, fw_t3, bw_t3, .. + + This routine checks that the lowest channel + number is measured first (aka the forward channel), but it doesn't catch the + last case as it doesn't know that fw_t1 and bw_t2 belong to different cycles. + Any ideas are welcome. + """ + filepath_fw = data_dir_double_single_ch1 + filepath_bw = data_dir_double_single_ch2 + cable_length = 2017.7 + + ds_fw = read_silixa_files(directory=filepath_fw) + ds_bw = read_silixa_files(directory=filepath_bw) + + # set stokes to verify proper time alignment + ds_fw.st.values = np.tile(np.arange(ds_fw.time.size)[None], (ds_fw.x.size, 1)) + ds_bw.st.values = np.tile(np.arange(ds_bw.time.size)[None], (ds_bw.x.size, 1)) + + # transform time index that is not included in: fw, bw, merged, to the ones included + ifw = list(i for i in range(6) if i not in inotinfw) + ibw = list(i for i in range(6) if i not in inotinbw) + iout = list(i for i in range(6) if i not in inotinout) + + ds = merge_double_ended( + ds_fw.isel(time=ifw), + ds_bw.isel(time=ibw), + cable_length=cable_length, + plot_result=False, + verbose=False) + assert ds.time.size == len(iout) and np.all(ds.st.isel(x=0) == iout), \ + f"FW:{ifw} & BW:{ibw} should lead to {iout} but instead leads to {ds.st.isel(x=0).values}"