Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
aggregated stats of the second pass we simply overwritten rather than append and vstacked()

unit tests are improved to properly test for the correct number of chunks
  • Loading branch information
TjarkMiener committed Aug 28, 2024
1 parent 3bae046 commit aa5666a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 40 deletions.
58 changes: 34 additions & 24 deletions src/ctapipe/monitoring/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import numpy as np
from astropy.table import Table
from astropy.table import Table, vstack

from ctapipe.core import TelescopeComponent
from ctapipe.core.traits import (
Expand Down Expand Up @@ -225,10 +225,15 @@ def second_pass(
raise ValueError(
"chunk_shift must be set if second pass over the data is requested"
)
# Check if at least one chunk is faulty
if np.all(valid_chunks):
raise ValueError(
"All chunks are valid. The second pass over the data is redundant."
)
# Get the aggregator
aggregator = self.stats_aggregator[self.stats_aggregator_type.tel[tel_id]]
# Conduct a second pass over the data
aggregated_stats_secondpass = None
aggregated_stats_secondpass = []
faulty_chunks_indices = np.where(~valid_chunks)[0]
for index in faulty_chunks_indices:
# Log information of the faulty chunks
Expand All @@ -254,31 +259,36 @@ def second_pass(
# Run the stats aggregator on the sliced dl1 table with a chunk_shift
# to sample the period of trouble (carflashes etc.) as effectively as possible.
# Checking for the length of the sliced table to be greater than the ``chunk_size``
# since it can be smaller if the last two chunks are faulty.
# since it can be smaller if the last two chunks are faulty. Note: The two last chunks
# can be overlapping during the first pass, so we simply ignore them if there are faulty.
if len(table_sliced) > aggregator.chunk_size:
aggregated_stats_secondpass = aggregator(
table=table_sliced,
masked_pixels_of_sample=masked_pixels_of_sample,
col_name=col_name,
chunk_shift=self.chunk_shift,
)
# Detect faulty pixels with multiple instances of OutlierDetector of the second pass
outlier_mask_secondpass = np.zeros_like(
aggregated_stats_secondpass["mean"], dtype=bool
)
for (
aggregated_val,
outlier_detector,
) in self.outlier_detectors.items():
outlier_mask_secondpass = np.logical_or(
outlier_mask_secondpass,
outlier_detector(aggregated_stats_secondpass[aggregated_val]),
aggregated_stats_secondpass.append(
aggregator(
table=table_sliced,
masked_pixels_of_sample=masked_pixels_of_sample,
col_name=col_name,
chunk_shift=self.chunk_shift,
)
)
# Add the outlier mask to the aggregated statistics
aggregated_stats_secondpass["outlier_mask"] = outlier_mask_secondpass
aggregated_stats_secondpass["is_valid"] = self._get_valid_chunks(
outlier_mask_secondpass
# Stack the aggregated statistics of each faulty chunk
aggregated_stats_secondpass = vstack(aggregated_stats_secondpass)
# Detect faulty pixels with multiple instances of OutlierDetector of the second pass
outlier_mask_secondpass = np.zeros_like(
aggregated_stats_secondpass["mean"], dtype=bool
)
for (
aggregated_val,
outlier_detector,
) in self.outlier_detectors.items():
outlier_mask_secondpass = np.logical_or(
outlier_mask_secondpass,
outlier_detector(aggregated_stats_secondpass[aggregated_val]),
)
# Add the outlier mask to the aggregated statistics
aggregated_stats_secondpass["outlier_mask"] = outlier_mask_secondpass
aggregated_stats_secondpass["is_valid"] = self._get_valid_chunks(
outlier_mask_secondpass
)
return aggregated_stats_secondpass

def _get_valid_chunks(self, outlier_mask):
Expand Down
53 changes: 37 additions & 16 deletions src/ctapipe/monitoring/tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,26 @@ def test_statistics_calculator(example_subarray):
"""test basic functionality of the StatisticsCalculator"""

# Create dummy data for testing
n_images = 5050
times = Time(
np.linspace(60117.911, 60117.9258, num=5000), scale="tai", format="mjd"
np.linspace(60117.911, 60117.9258, num=n_images), scale="tai", format="mjd"
)
event_ids = np.linspace(35, 725000, num=5000, dtype=int)
event_ids = np.linspace(35, 725000, num=n_images, dtype=int)
rng = np.random.default_rng(0)
charge_data = rng.normal(77.0, 10.0, size=(5000, 2, 1855))
charge_data = rng.normal(77.0, 10.0, size=(n_images, 2, 1855))
# Create tables
charge_table = Table(
[times, event_ids, charge_data],
names=("time_mono", "event_id", "image"),
)
# Initialize the aggregator and calculator
aggregator = PlainAggregator(subarray=example_subarray, chunk_size=1000)
chunk_size = 1000
aggregator = PlainAggregator(subarray=example_subarray, chunk_size=chunk_size)
chunk_shift = 500
calculator = StatisticsCalculator(
subarray=example_subarray,
stats_aggregator=aggregator,
chunk_shift=100,
chunk_shift=chunk_shift,
)
# Compute the statistical values
stats = calculator.first_pass(table=charge_table, tel_id=1)
Expand All @@ -42,9 +45,12 @@ def test_statistics_calculator(example_subarray):
table=charge_table, valid_chunks=valid_chunks, tel_id=1
)
# Stack the statistic values from the first and second pass
stats_combined = vstack([stats, stats_chunk_shift])
# Sort the combined aggregated statistic values by starting time
stats_combined.sort(["time_start"])
stats_stacked = vstack([stats, stats_chunk_shift])
# Sort the stacked aggregated statistic values by starting time
stats_stacked.sort(["time_start"])
print(stats)
print(stats_chunk_shift)
print(stats_stacked)
# Check if the calculated statistical values are reasonable
# for a camera with two gain channels
np.testing.assert_allclose(stats[0]["mean"], 77.0, atol=2.5)
Expand All @@ -55,7 +61,21 @@ def test_statistics_calculator(example_subarray):
np.testing.assert_allclose(stats_chunk_shift[0]["std"], 10.0, atol=2.5)
# Check if overlapping chunks of the second pass were aggregated
assert stats_chunk_shift is not None
assert len(stats_combined) > len(stats)
# Check if the number of aggregated chunks is correct
# In the first pass, the number of chunks is equal to the
# number of images divided by the chunk size plus one
# overlapping chunk at the end.
expected_len_firstpass = n_images // chunk_size + 1
assert len(stats) == expected_len_firstpass
# In the second pass, the number of chunks is equal to the
# number of images divided by the chunk shift minus the
# number of chunks in the first pass, since we set all
# chunks to be faulty.
expected_len_secondpass = (n_images // chunk_shift) - expected_len_firstpass
assert len(stats_chunk_shift) == expected_len_secondpass
# The total number of aggregated chunks is the sum of the
# number of chunks in the first and second pass.
assert len(stats_stacked) == expected_len_firstpass + expected_len_secondpass


def test_outlier_detector(example_subarray):
Expand Down Expand Up @@ -113,14 +133,15 @@ def test_outlier_detector(example_subarray):
stats_second_pass = calculator.second_pass(
table=ped_table, valid_chunks=stats_first_pass["is_valid"].data, tel_id=1
)
stats_combined = vstack([stats_first_pass, stats_second_pass])
# Sort the combined aggregated statistic values by starting time
stats_combined.sort(["time_start"])
# Stack the statistic values from the first and second pass
stats_stacked = vstack([stats_first_pass, stats_second_pass])
# Sort the stacked aggregated statistic values by starting time
stats_stacked.sort(["time_start"])
# Check if overlapping chunks of the second pass were aggregated
assert stats_second_pass is not None
assert len(stats_combined) > len(stats_second_pass)
assert len(stats_stacked) > len(stats_second_pass)
# Check if the calculated statistical values are reasonable
# for a camera with two gain channels
np.testing.assert_allclose(stats_combined[0]["mean"], 2.0, atol=2.5)
np.testing.assert_allclose(stats_combined[1]["median"], 2.0, atol=2.5)
np.testing.assert_allclose(stats_combined[0]["std"], 5.0, atol=2.5)
np.testing.assert_allclose(stats_stacked[0]["mean"], 2.0, atol=2.5)
np.testing.assert_allclose(stats_stacked[1]["median"], 2.0, atol=2.5)
np.testing.assert_allclose(stats_stacked[0]["std"], 5.0, atol=2.5)

0 comments on commit aa5666a

Please sign in to comment.