Skip to content

Commit

Permalink
Allow configuring filter type and order
Browse files Browse the repository at this point in the history
We don't always want to be restricted to a 4th-order Butterworth
highpass filter, so this opens up a few more options. It is of course
always possible to use an arbitrary filter, but that is less
accessible.

Closes #6.
  • Loading branch information
angus-g committed Jul 15, 2022
1 parent 4ef6498 commit 5685e90
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 20 deletions.
6 changes: 5 additions & 1 deletion docs/filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ obtained by :py:func:`scipy.signal.butter`).
There are other filters available in the :py:mod:`filtering.filter`
module, such as one that performs the filtering in frequency space
(may give a sharper cutoff, at the expense of possible ringing), or
that allows variation of the cutoff frequency over the domain.
that allows variation of the cutoff frequency over the
domain. Alternatively, the default
:py:class:`~filtering.filter.Filter` can be constructed with a
different order, and as a highpass, lowpass, or bandpass filter,
depending on the required application.

If an alternate filter is constructed, it can be attached to the
:py:class:`~filtering.filtering.LagrangeFilter`::
Expand Down
34 changes: 22 additions & 12 deletions filtering/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@ class Filter(object):
applying the filter to advected particle data.
Args:
frequency (float): The high-pass cutoff frequency of the filter
in [/s].
frequency (Union[float, Tuple[float, float]]): The low-pass or high-pass cutoff
frequency of the filter in [/s], or a pair denoting the band to pass.
fs (float): The sampling frequency of the data over which the
filter is applied in [s].
**kwargs (Optional): Additional arguments are passed to the
:func:`~create_filter` method.
"""

def __init__(self, frequency, fs):
self._filter = Filter.create_filter(frequency, fs)
def __init__(self, frequency, fs, **kwargs):
self._filter = Filter.create_filter(frequency, fs, **kwargs)

@staticmethod
def create_filter(frequency, fs):
def create_filter(frequency, fs, order=4, filter_type="highpass"):
"""Create a filter.
This creates an analogue Butterworth filter with the given
Expand All @@ -40,10 +42,13 @@ def create_filter(frequency, fs):
frequency (float): The high-pass angular cutoff frequency of the filter
in [/s].
fs (float): The sampling frequency of the data in [s].
order (Optional[int]): The filter order, default 4.
filter_type (Optional[str]): The type of filter, one of ("highpass",
"bandpass", "lowpass"), defaults to "highpass".
"""

return signal.butter(4, frequency, "highpass", fs=fs, output="sos")
return signal.butter(order, frequency, filter_type, fs=fs, output="sos")

@staticmethod
def pad_window(x, centre_index, min_window):
Expand Down Expand Up @@ -175,18 +180,19 @@ class SpatialFilter(Filter):
Args:
frequencies (numpy.ndarray): An array with the same number
of elements as seeded particles, containing the cutoff
frequency to be used for each particle.
cutoff frequency at that location in [/s].
frequency to be used for each particle, in [/s].
fs (float): The sampling frequency of the data over which the
filter is applied in [s].
**kwargs (Optional): Additional arguments are passed to the
:func:`~create_filter` method.
"""

def __init__(self, frequencies, fs):
self._filter = SpatialFilter.create_filter(frequencies, fs)
def __init__(self, frequencies, fs, **kwargs):
self._filter = SpatialFilter.create_filter(frequencies, fs, **kwargs)

@staticmethod
def create_filter(frequencies, fs):
def create_filter(frequencies, fs, order=4, filter_type="highpass"):
"""Create a series of filters.
This creates an analogue Butterworth filter with the given
Expand All @@ -196,10 +202,14 @@ def create_filter(frequencies, fs):
frequencies (numpy.ndarray): The high-pass cutoff frequencies of the filters
in [/s].
fs (float): The sampling frequency of the data in [s].
order (Optional[int]): The filter order, default 4.
filter_type (Optional[str]): The type of filter, one of ("highpass",
"lowpass"), defaults to "highpass". Note that bandpass spatial filters
aren't supported.
"""

return sosfilt.butter(4, frequencies, "highpass", fs=fs, output="sos")
return sosfilt.butter(order, frequencies, filter_type, fs=fs, output="sos")

def apply_filter(self, data, time_index, min_window=None):
"""Apply the filter to an array of data."""
Expand Down
49 changes: 42 additions & 7 deletions test/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
import filtering


@pytest.fixture
def lats_grid():
lons = np.array([0])
lats = np.array([1, 2])

lons, lats = np.meshgrid(lons, lats)

return lats


def test_frequency_filter(leewave_data):
"""Test creation and application of frequency-space step filter."""

Expand All @@ -32,17 +42,42 @@ def test_frequency_filter(leewave_data):
assert np.all((leewave_data.U_orig.data - U_filt[0, :]) ** 2 < 3e-8)


def test_spatial_filter():
def test_spatial_filter(lats_grid):
"""Test creation and frequency response of a latitude-dependent filter."""

lons = np.array([0])
lats = np.array([1, 2])

lons, lats = np.meshgrid(lons, lats)

f = lats * 0.1
f = lats_grid * 0.1
filt = filtering.filter.SpatialFilter(f.flatten(), 1)

for freq, filter_obj in zip(f, filt._filter):
w, h = signal.sosfreqz(filter_obj)
assert np.all(abs(h)[w < freq] < 0.1)


@pytest.mark.parametrize("order", [3, 4])
@pytest.mark.parametrize(
"filter_type,freq",
[("highpass", 1e-4), ("bandpass", (1e-4, 1e-2)), ("lowpass", 1e-2)],
)
def test_create_filter(order, filter_type, freq):
"""Test parameters for filter creation."""

filt = filtering.filter.Filter(freq, 1, order=order, filter_type=filter_type)


@pytest.mark.parametrize("order", [3, 4])
@pytest.mark.parametrize("filter_type", ["highpass", "lowpass"])
def test_create_spatial_filter(lats_grid, order, filter_type):
"""Test parameters for spatial filter creation."""

f = lats_grid * 0.1
filt = filtering.filter.SpatialFilter(
f.flatten(), 1, order=order, filter_type=filter_type
)


def test_create_bandpass_spatial_filter(lats_grid):
"""Expect a failure for creating a bandpass spatial filter."""

f = lats_grid * (0.1, 0.2)
with pytest.raises(NotImplementedError):
filt = filtering.filter.SpatialFilter(f.flatten(), 1, filter_type="bandpass")

0 comments on commit 5685e90

Please sign in to comment.