Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check uniqueness of input paths and obs_ids in merge tool #2611

Merged
merged 3 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changes/2611.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
The ``ctapipe-merge`` tool now checks for duplicated input files and
raises an error in that case.

The ``HDF5Merger`` class, and thus also the ``ctapipe-merge`` tool,
now checks for duplicated obs_ids during merging, to prevent
invalid output files.
8 changes: 8 additions & 0 deletions src/ctapipe/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,14 @@ def proton_train_clf(model_tmp_path, energy_regressor_path):
],
raises=True,
)

# modify obs_ids by adding a constant, this enables merging gamma and proton files
# which is used in the merge tool tests.
with tables.open_file(outpath, mode="r+") as f:
for table in f.walk_nodes("/", "Table"):
if "obs_id" in table.colnames:
obs_id = table.col("obs_id")
table.modify_column(colname="obs_id", column=obs_id + 1_000_000_000)
return outpath


Expand Down
30 changes: 30 additions & 0 deletions src/ctapipe/io/hdf5merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def __init__(self, output_path=None, **kwargs):
self.data_model_version = None
self.subarray = None
self.meta = None
self._merged_obs_ids = set()

# output file existed, so read subarray and data model version to make sure
# any file given matches what we already have
if appending:
Expand All @@ -202,6 +204,9 @@ def __init__(self, output_path=None, **kwargs):
)
self.required_nodes = _get_required_nodes(self.h5file)

# this will update _merged_obs_ids from existing input file
self._check_obs_ids(self.h5file)

def __call__(self, other: str | Path | tables.File):
"""
Append file ``other`` to the output file
Expand Down Expand Up @@ -267,7 +272,32 @@ def _check_can_merge(self, other):
f"Required node {node_path} not found in {other.filename}"
)

def _check_obs_ids(self, other):
keys = [
"/configuration/observation/observation_block",
"/dl1/event/subarray/trigger",
]

for key in keys:
if key in other.root:
obs_ids = other.root[key].col("obs_id")
break
else:
raise CannotMerge(
f"Input file {other.filename} is missing keys required to"
f" check for duplicated obs_ids. Tried: {keys}"
)

duplicated = self._merged_obs_ids.intersection(obs_ids)
if len(duplicated) > 0:
msg = f"Input file {other.filename} contains obs_ids already included in output file: {duplicated}"
raise CannotMerge(msg)

self._merged_obs_ids.update(obs_ids)

def _append(self, other):
self._check_obs_ids(other)

# Configuration
self._append_subarray(other)

Expand Down
27 changes: 26 additions & 1 deletion src/ctapipe/io/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_simple(tmp_path, gamma_train_clf, proton_train_clf):
merger(proton_train_clf)

subarray = SubarrayDescription.from_hdf(gamma_train_clf)
assert subarray == SubarrayDescription.from_hdf(output), "Subarays do not match"
assert subarray == SubarrayDescription.from_hdf(output), "Subarrays do not match"

tel_groups = [
"/dl1/event/telescope/parameters",
Expand Down Expand Up @@ -164,3 +164,28 @@ def test_muon(tmp_path, dl1_muon_output_file):
n_input = len(input_table)
assert len(table) == n_input
assert_table_equal(table, input_table)


def test_duplicated_obs_ids(tmp_path, dl2_shower_geometry_file):
    """Merging a file whose obs_ids are already in the output must fail."""
    from ctapipe.io.hdf5merger import CannotMerge, HDF5Merger

    error_pattern = "Input file .* contains obs_ids already included"
    output = tmp_path / "invalid.dl1.h5"

    # freshly created output file: merging the same input twice is rejected
    with HDF5Merger(output) as merge:
        merge(dl2_shower_geometry_file)

        with pytest.raises(CannotMerge, match=error_pattern):
            merge(dl2_shower_geometry_file)

    # appending to an existing output file: duplicates are also rejected,
    # because the merger reads already-present obs_ids on open
    with HDF5Merger(output, overwrite=True) as merge:
        merge(dl2_shower_geometry_file)

    with HDF5Merger(output, append=True) as merge:
        with pytest.raises(CannotMerge, match=error_pattern):
            merge(dl2_shower_geometry_file)
8 changes: 8 additions & 0 deletions src/ctapipe/tools/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import sys
from argparse import ArgumentParser
from collections import Counter
from pathlib import Path

from tqdm.auto import tqdm
Expand Down Expand Up @@ -161,6 +162,13 @@ def setup(self):
)
sys.exit(1)

counts = Counter(self.input_files)
duplicated = [p for p, c in counts.items() if c > 1]
if len(duplicated) > 0:
raise ToolConfigurationError(
f"Same file given multiple times. Duplicated files are: {duplicated}"
)

self.merger = self.enter_context(HDF5Merger(parent=self))
if self.merger.output_path in self.input_files:
raise ToolConfigurationError(
Expand Down
28 changes: 23 additions & 5 deletions src/ctapipe/tools/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from pathlib import Path

import numpy as np
import pytest
import tables
from astropy.table import vstack
from astropy.utils.diff import report_diff_values

from ctapipe.core import run_tool
from ctapipe.core import ToolConfigurationError, run_tool
from ctapipe.io import TableLoader
from ctapipe.io.astropy_helpers import read_table
from ctapipe.io.tests.test_astropy_helpers import assert_table_equal
Expand Down Expand Up @@ -176,7 +177,6 @@ def test_muon(tmp_path, dl1_muon_output_file):
argv=[
f"--output={output}",
str(dl1_muon_output_file),
str(dl1_muon_output_file),
],
raises=True,
)
Expand All @@ -185,6 +185,24 @@ def test_muon(tmp_path, dl1_muon_output_file):
input_table = read_table(dl1_muon_output_file, "/dl1/event/telescope/muon/tel_001")

n_input = len(input_table)
assert len(table) == 2 * n_input
assert_table_equal(table[:n_input], input_table)
assert_table_equal(table[n_input:], input_table)
assert len(table) == n_input
assert_table_equal(table, input_table)


def test_duplicated(tmp_path, dl1_file, dl1_proton_file):
    """Passing the same input path twice must abort with a configuration error."""
    from ctapipe.tools.merge import MergeTool

    output = tmp_path / "invalid.dl1.h5"

    # dl1_file is listed twice on purpose to trigger the duplicate check
    argv = [
        str(dl1_file),
        str(dl1_proton_file),
        str(dl1_file),
        f"--output={output}",
        "--overwrite",
    ]

    with pytest.raises(ToolConfigurationError, match="Same file given multiple times"):
        run_tool(MergeTool(), argv=argv, cwd=tmp_path, raises=True)