Skip to content

Commit

Permalink
Adapt CameraCalibrator and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnoe committed Jul 7, 2023
1 parent d0b1630 commit bdb7eb5
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 81 deletions.
57 changes: 32 additions & 25 deletions ctapipe/calib/camera/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from numba import float32, float64, guvectorize, int64

from ctapipe.containers import TelescopeDL1Container
from ctapipe.containers import TelescopeDL1Container, TelescopeEventContainer
from ctapipe.core import TelescopeComponent
from ctapipe.core.traits import (
BoolTelescopeParameter,
Expand Down Expand Up @@ -153,8 +153,8 @@ def __init__(
parent=self,
)

def _check_r1_empty(self, waveforms):
if waveforms is None:
def _check_r1_empty(self, r1):
if r1 is None or r1.waveform is None:
if not self._r1_empty_warn:
warnings.warn(
"Encountered an event with no R1 data. "
Expand All @@ -165,8 +165,8 @@ def _check_r1_empty(self, waveforms):
else:
return False

def _check_dl0_empty(self, waveforms):
if waveforms is None:
def _check_dl0_empty(self, dl0):
if dl0 is None or dl0.waveform is None:
if not self._dl0_empty_warn:
warnings.warn(
"Encountered an event with no DL0 data. "
Expand All @@ -177,37 +177,43 @@ def _check_dl0_empty(self, waveforms):
else:
return False

def _calibrate_dl0(self, event, tel_id):
waveforms = event.r1.tel[tel_id].waveform
selected_gain_channel = event.r1.tel[tel_id].selected_gain_channel
if self._check_r1_empty(waveforms):
def r1_to_dl0(self, tel_event: TelescopeEventContainer):
if self._check_r1_empty(tel_event.r1):
return

tel_id = tel_event.index.tel_id
waveforms = tel_event.r1.waveform

selected_gain_channel = tel_event.r1.selected_gain_channel
reduced_waveforms_mask = self.data_volume_reducer(
waveforms, tel_id=tel_id, selected_gain_channel=selected_gain_channel
waveforms,
tel_id=tel_id,
selected_gain_channel=selected_gain_channel,
)

waveforms_copy = waveforms.copy()
waveforms_copy[~reduced_waveforms_mask] = 0
event.dl0.tel[tel_id].waveform = waveforms_copy
event.dl0.tel[tel_id].selected_gain_channel = selected_gain_channel
tel_event.dl0.waveform = waveforms_copy
tel_event.dl0.selected_gain_channel = selected_gain_channel

def _calibrate_dl1(self, event, tel_id):
waveforms = event.dl0.tel[tel_id].waveform
if self._check_dl0_empty(waveforms):
def dl0_to_dl1(self, tel_event: TelescopeEventContainer):
if self._check_dl0_empty(tel_event.dl0):
return

tel_id = tel_event.index.tel_id
dl0 = tel_event.dl0
waveforms = dl0.waveform
n_pixels, n_samples = waveforms.shape

selected_gain_channel = event.dl0.tel[tel_id].selected_gain_channel
selected_gain_channel = dl0.selected_gain_channel
broken_pixels = _get_invalid_pixels(
n_pixels,
event.mon.tel[tel_id].pixel_status,
tel_event.mon.pixel_status,
selected_gain_channel,
)

dl1_calib = event.calibration.tel[tel_id].dl1
time_shift = event.calibration.tel[tel_id].dl1.time_shift
dl1_calib = tel_event.calibration.dl1
time_shift = dl1_calib.time_shift
readout = self.subarray.tel[tel_id].camera.readout

# subtract any remaining pedestal before extraction
Expand Down Expand Up @@ -267,7 +273,7 @@ def _calibrate_dl1(self, event, tel_id):
)

# store the results in the event structure
event.dl1.tel[tel_id] = dl1
tel_event.dl1 = dl1

def __call__(self, event):
"""
Expand All @@ -280,11 +286,12 @@ def __call__(self, event):
event : container
A `~ctapipe.containers.ArrayEventContainer` event container
"""
# TODO: How to handle different calibrations depending on tel_id?
tel = event.r1.tel or event.dl0.tel or event.dl1.tel
for tel_id in tel.keys():
self._calibrate_dl0(event, tel_id)
self._calibrate_dl1(event, tel_id)
for tel_event in event.tel.values():
self.calibrate_tel_event(tel_event)

def calibrate_tel_event(self, tel_event):
self.r1_to_dl0(tel_event)
self.dl0_to_dl1(tel_event)


def shift_waveforms(waveforms, time_shift_samples):
Expand Down
137 changes: 81 additions & 56 deletions ctapipe/calib/camera/tests/test_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from traitlets.config import Config

from ctapipe.calib.camera.calibrator import CameraCalibrator
from ctapipe.containers import ArrayEventContainer
from ctapipe.containers import (
ArrayEventContainer,
TelescopeDL0Container,
TelescopeDL1Container,
TelescopeEventContainer,
TelescopeEventIndexContainer,
TelescopeR1Container,
)
from ctapipe.image.extractor import (
FullWaveformSum,
GlobalPeakWindowSum,
Expand All @@ -21,11 +28,11 @@


def test_camera_calibrator(example_event, example_subarray):
tel_id = list(example_event.r0.tel)[0]
tel_event = next(iter(example_event.tel.values()))
calibrator = CameraCalibrator(subarray=example_subarray)
calibrator(example_event)
image = example_event.dl1.tel[tel_id].image
peak_time = example_event.dl1.tel[tel_id].peak_time
image = tel_event.dl1.image
peak_time = tel_event.dl1.peak_time
assert image is not None
assert peak_time is not None
assert image.shape == (1764,)
Expand Down Expand Up @@ -99,53 +106,68 @@ def test_config(example_subarray):

def test_check_r1_empty(example_event, example_subarray):
calibrator = CameraCalibrator(subarray=example_subarray)
tel_id = list(example_event.r0.tel)[0]
waveform = example_event.r1.tel[tel_id].waveform.copy()
tel_id, tel_event = next(iter(example_event.tel.items()))
waveform = tel_event.r1.waveform.copy()
with pytest.warns(UserWarning):
example_event.r1.tel[tel_id].waveform = None
calibrator._calibrate_dl0(example_event, tel_id)
assert example_event.dl0.tel[tel_id].waveform is None
tel_event.r1.waveform = None
calibrator.r1_to_dl0(tel_event)
assert tel_event.dl0.waveform is None

assert calibrator._check_r1_empty(None) is True
assert calibrator._check_r1_empty(waveform) is False
assert calibrator._check_r1_empty(TelescopeR1Container(waveform=None)) is True
assert calibrator._check_r1_empty(TelescopeR1Container(waveform=waveform)) is False

calibrator = CameraCalibrator(
subarray=example_subarray,
image_extractor=FullWaveformSum(subarray=example_subarray),
)
event = ArrayEventContainer()
event.dl0.tel[tel_id].waveform = np.full((2048, 128), 2)
event.tel[tel_id] = TelescopeEventContainer(
index=TelescopeEventIndexContainer(obs_id=1, event_id=1, tel_id=tel_id),
dl0=TelescopeDL0Container(waveform=np.full((2048, 128), 2)),
)
with pytest.warns(UserWarning):
calibrator(event)
assert (event.dl0.tel[tel_id].waveform == 2).all()
assert (event.dl1.tel[tel_id].image == 2 * 128).all()
assert (event.tel[tel_id].dl0.waveform == 2).all()
assert (event.tel[tel_id].dl1.image == 2 * 128).all()


def test_check_dl0_empty(example_event, example_subarray):
calibrator = CameraCalibrator(subarray=example_subarray)
tel_id = list(example_event.r0.tel)[0]
calibrator._calibrate_dl0(example_event, tel_id)
waveform = example_event.dl0.tel[tel_id].waveform.copy()
tel_id, tel_event = next(iter(example_event.tel.items()))

calibrator.r1_to_dl0(tel_event)
waveform = tel_event.dl0.waveform.copy()

with pytest.warns(UserWarning):
example_event.dl0.tel[tel_id].waveform = None
calibrator._calibrate_dl1(example_event, tel_id)
assert example_event.dl1.tel[tel_id].image is None
tel_event.dl0.waveform = None
calibrator.dl0_to_dl1(tel_event)
assert tel_event.dl1.image is None

assert calibrator._check_dl0_empty(None) is True
assert calibrator._check_dl0_empty(waveform) is False
assert calibrator._check_dl0_empty(TelescopeDL0Container(waveform=None)) is True
assert (
calibrator._check_dl0_empty(TelescopeDL0Container(waveform=waveform)) is False
)

calibrator = CameraCalibrator(subarray=example_subarray)
event = ArrayEventContainer()
event.dl1.tel[tel_id].image = np.full(2048, 2)
tel_event = TelescopeEventContainer(
index=TelescopeEventIndexContainer(obs_id=1, event_id=1, tel_id=tel_id),
dl1=TelescopeDL1Container(image=np.full(2048, 2)),
)
event.tel[tel_id] = tel_event
with pytest.warns(UserWarning):
calibrator(event)
assert (event.dl1.tel[tel_id].image == 2).all()
assert (tel_event.dl1.image == 2).all()


def test_dl1_charge_calib(example_subarray):
# copy because we mutate the camera, should not affect other tests
rng = np.random.default_rng(1)
tel_id = 1
subarray = deepcopy(example_subarray)
camera = subarray.tel[1].camera
camera = subarray.tel[tel_id].camera
# test with a sampling_rate different than 1 to
# test if we handle time vs. slices correctly
sampling_rate = 2
Expand All @@ -155,60 +177,63 @@ def test_dl1_charge_calib(example_subarray):
n_samples = 96
mid = n_samples // 2
pulse_sigma = 6
random = np.random.default_rng(1)
x = np.arange(n_samples)

# Randomize times and create pulses
time_offset = random.uniform(-10, +10, n_pixels)
time_offset = rng.uniform(-10, +10, n_pixels)
y = norm.pdf(x, mid + time_offset[:, np.newaxis], pulse_sigma).astype("float32")

camera.readout.reference_pulse_shape = norm.pdf(x, mid, pulse_sigma)[np.newaxis, :]
camera.readout.reference_pulse_sample_width = 1 / camera.readout.sampling_rate

# Define absolute calibration coefficients
absolute = random.uniform(100, 1000, n_pixels).astype("float32")
absolute = rng.uniform(100, 1000, n_pixels).astype("float32")
y *= absolute[:, np.newaxis]

# Define relative coefficients
relative = random.normal(1, 0.01, n_pixels)
relative = rng.normal(1, 0.01, n_pixels)
y /= relative[:, np.newaxis]

# Define pedestal
pedestal = random.uniform(-4, 4, n_pixels)
pedestal = rng.uniform(-4, 4, n_pixels)
y += pedestal[:, np.newaxis]

event = ArrayEventContainer()
tel_id = list(subarray.tel.keys())[0]
event.dl0.tel[tel_id].waveform = y
event.dl0.tel[tel_id].selected_gain_channel = np.zeros(len(y), dtype=int)
event.r1.tel[tel_id].selected_gain_channel = np.zeros(len(y), dtype=int)
event.tel[tel_id] = TelescopeEventContainer(
index=TelescopeEventIndexContainer(obs_id=1, event_id=1, tel_id=tel_id),
dl0=TelescopeDL0Container(
waveform=y,
selected_gain_channel=np.zeros(len(y), dtype=int),
),
)

# Test default
calibrator = CameraCalibrator(
subarray=subarray, image_extractor=FullWaveformSum(subarray=subarray)
subarray=subarray,
image_extractor=FullWaveformSum(subarray=subarray),
)
calibrator(event)
np.testing.assert_allclose(event.dl1.tel[tel_id].image, y.sum(1), rtol=1e-4)
np.testing.assert_allclose(event.tel[tel_id].dl1.image, y.sum(1), rtol=1e-4)

event.calibration.tel[tel_id].dl1.pedestal_offset = pedestal
event.calibration.tel[tel_id].dl1.absolute_factor = absolute
event.calibration.tel[tel_id].dl1.relative_factor = relative
event.tel[tel_id].calibration.dl1.pedestal_offset = pedestal
event.tel[tel_id].calibration.dl1.absolute_factor = absolute
event.tel[tel_id].calibration.dl1.relative_factor = relative

# Test without timing corrections
calibrator(event)
dl1 = event.dl1.tel[tel_id]
dl1 = event.tel[tel_id].dl1
np.testing.assert_allclose(dl1.image, 1, rtol=1e-5)
expected_peak_time = (mid + time_offset) / sampling_rate
np.testing.assert_allclose(dl1.peak_time, expected_peak_time, rtol=1e-5)

# test with timing corrections
event.calibration.tel[tel_id].dl1.time_shift = time_offset / sampling_rate
event.tel[tel_id].calibration.dl1.time_shift = time_offset / sampling_rate
calibrator(event)

# more rtol since shifting might lead to reduced integral
np.testing.assert_allclose(event.dl1.tel[tel_id].image, 1, rtol=1e-5)
np.testing.assert_allclose(event.tel[tel_id].dl1.image, 1, rtol=1e-5)
np.testing.assert_allclose(
event.dl1.tel[tel_id].peak_time, mid / sampling_rate, atol=1
event.tel[tel_id].dl1.peak_time, mid / sampling_rate, atol=1
)

# test not applying time shifts
Expand All @@ -217,9 +242,9 @@ def test_dl1_charge_calib(example_subarray):
calibrator.apply_waveform_time_shift = False
calibrator(event)

np.testing.assert_allclose(event.dl1.tel[tel_id].image, 1, rtol=1e-4)
np.testing.assert_allclose(event.tel[tel_id].dl1.image, 1, rtol=1e-4)
np.testing.assert_allclose(
event.dl1.tel[tel_id].peak_time, expected_peak_time, atol=1
event.tel[tel_id].dl1.peak_time, expected_peak_time, atol=1
)

# We now use GlobalPeakWindowSum to see the effect of missing charge
Expand All @@ -232,9 +257,9 @@ def test_dl1_charge_calib(example_subarray):
calibrator(event)
# test with timing corrections, should work
# higher rtol because we cannot shift perfectly
np.testing.assert_allclose(event.dl1.tel[tel_id].image, 1, rtol=0.01)
np.testing.assert_allclose(event.tel[tel_id].dl1.image, 1, rtol=0.01)
np.testing.assert_allclose(
event.dl1.tel[tel_id].peak_time, mid / sampling_rate, atol=1
event.tel[tel_id].dl1.peak_time, mid / sampling_rate, atol=1
)

# test deactivating timing corrections
Expand All @@ -243,8 +268,8 @@ def test_dl1_charge_calib(example_subarray):

# make sure we chose an example where the time shifts matter
# charges should be quite off due to summing around global shift
assert not np.allclose(event.dl1.tel[tel_id].image, 1, rtol=0.1)
assert not np.allclose(event.dl1.tel[tel_id].peak_time, mid / sampling_rate, atol=1)
assert not np.allclose(event.tel[tel_id].dl1.image, 1, rtol=0.1)
assert not np.allclose(event.tel[tel_id].dl1.peak_time, mid / sampling_rate, atol=1)


def test_shift_waveforms():
Expand Down Expand Up @@ -283,22 +308,22 @@ def test_invalid_pixels(example_event, example_subarray):
)
# going to modify this
event = deepcopy(example_event)
tel_id = list(event.r0.tel)[0]
tel_id, tel_event = next(iter(event.tel.items()))
camera = example_subarray.tel[tel_id].camera
sampling_rate = camera.readout.sampling_rate.to_value(u.GHz)

event.mon.tel[tel_id].pixel_status.flatfield_failing_pixels[:, 0] = True
event.r1.tel[tel_id].waveform.fill(0.0)
event.r1.tel[tel_id].waveform[1:, 20] = 1.0
event.r1.tel[tel_id].waveform[0, 10] = 9999
tel_event.mon.pixel_status.flatfield_failing_pixels[:, 0] = True
tel_event.r1.waveform.fill(0.0)
tel_event.r1.waveform[1:, 20] = 1.0
tel_event.r1.waveform[0, 10] = 9999

calibrator = CameraCalibrator(
subarray=example_subarray,
config=config,
)
calibrator(event)
assert np.all(event.dl1.tel[tel_id].image == 1.0)
assert np.all(event.dl1.tel[tel_id].peak_time == 20.0 / sampling_rate)
assert np.all(tel_event.dl1.image == 1.0)
assert np.all(tel_event.dl1.peak_time == 20.0 / sampling_rate)

# test we can set the invalid pixel handler to None
config.CameraCalibrator.invalid_pixel_handler_type = None
Expand All @@ -307,5 +332,5 @@ def test_invalid_pixels(example_event, example_subarray):
config=config,
)
calibrator(event)
assert event.dl1.tel[tel_id].image[0] == 9999
assert event.dl1.tel[tel_id].peak_time[0] == 10.0 / sampling_rate
assert tel_event.dl1.image[0] == 9999
assert tel_event.dl1.peak_time[0] == 10.0 / sampling_rate

0 comments on commit bdb7eb5

Please sign in to comment.