Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TjarkMiener committed Jun 12, 2024
1 parent 5c6595b commit ba00bf3
Showing 1 changed file with 21 additions and 23 deletions.
44 changes: 21 additions & 23 deletions src/ctapipe/calib/camera/tests/test_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@
Tests for StatisticsExtractor and related functions
"""

from astropy.table import QTable
import numpy as np
import pytest
from astropy.table import QTable

from ctapipe.calib.camera.extractor import PlainExtractor, SigmaClippingExtractor


@pytest.fixture(name="test_plainextractor")
def fixture_test_plainextractor(example_subarray):
"""test the PlainExtractor"""
return PlainExtractor(
subarray=example_subarray, chunk_size=2500
)
return PlainExtractor(subarray=example_subarray, chunk_size=2500)


@pytest.fixture(name="test_sigmaclippingextractor")
def fixture_test_sigmaclippingextractor(example_subarray):
"""test the SigmaClippingExtractor"""
return SigmaClippingExtractor(
subarray=example_subarray, chunk_size=2500
)
return SigmaClippingExtractor(subarray=example_subarray, chunk_size=2500)


def test_extractors(test_plainextractor, test_sigmaclippingextractor):
"""test basic functionality of the StatisticsExtractors"""
Expand All @@ -36,17 +36,17 @@ def test_extractors(test_plainextractor, test_sigmaclippingextractor):
dl1_table=flatfield_dl1_table
)

assert np.any(np.abs(plain_stats_list[0].mean - 2.0) > 1.5) is False
assert np.any(np.abs(sigmaclipping_stats_list[0].mean - 77.0) > 1.5) is False
assert not np.any(np.abs(plain_stats_list[0].mean - 2.0) > 1.5)
assert not np.any(np.abs(sigmaclipping_stats_list[0].mean - 77.0) > 1.5)

assert np.any(np.abs(plain_stats_list[0].mean - 2.0) > 1.5) is False
assert np.any(np.abs(sigmaclipping_stats_list[0].mean - 77.0) > 1.5) is False
assert not np.any(np.abs(plain_stats_list[0].mean - 2.0) > 1.5)
assert not np.any(np.abs(sigmaclipping_stats_list[0].mean - 77.0) > 1.5)

assert np.any(np.abs(plain_stats_list[1].median - 2.0) > 1.5) is False
assert np.any(np.abs(sigmaclipping_stats_list[1].median - 77.0) > 1.5) is False
assert not np.any(np.abs(plain_stats_list[1].median - 2.0) > 1.5)
assert not np.any(np.abs(sigmaclipping_stats_list[1].median - 77.0) > 1.5)

assert np.any(np.abs(plain_stats_list[0].std - 5.0) > 1.5) is False
assert np.any(np.abs(sigmaclipping_stats_list[0].std - 10.0) > 1.5) is False
assert not np.any(np.abs(plain_stats_list[0].std - 5.0) > 1.5)
assert not np.any(np.abs(sigmaclipping_stats_list[0].std - 10.0) > 1.5)


def test_check_outliers(test_sigmaclippingextractor):
Expand All @@ -63,11 +63,11 @@ def test_check_outliers(test_sigmaclippingextractor):
)

# check if outliers where detected correctly
assert sigmaclipping_stats_list[0].median_outliers[0][120] is True
assert sigmaclipping_stats_list[0].median_outliers[1][67] is True
assert sigmaclipping_stats_list[1].median_outliers[0][120] is True
assert sigmaclipping_stats_list[1].median_outliers[1][67] is True
assert sigmaclipping_stats_list[0].median_outliers[0][120]
assert sigmaclipping_stats_list[0].median_outliers[1][67]
assert sigmaclipping_stats_list[1].median_outliers[0][120]
assert sigmaclipping_stats_list[1].median_outliers[1][67]


def test_check_chunk_shift(test_sigmaclippingextractor):
"""test the chunk shift option and the boundary case for the last chunk"""
Expand All @@ -77,10 +77,8 @@ def test_check_chunk_shift(test_sigmaclippingextractor):
# insert outliers
flatfield_dl1_table = QTable([times, flatfield_dl1_data], names=("time", "image"))
sigmaclipping_stats_list = test_sigmaclippingextractor(
dl1_table=flatfield_dl1_table,
chunk_shift=2000
dl1_table=flatfield_dl1_table, chunk_shift=2000
)

# check if three chunks are used for the extraction
assert len(sigmaclipping_stats_list) == 3

0 comments on commit ba00bf3

Please sign in to comment.