Commit da44497

Merge pull request #20 from TUW-GEO/pre_post_process
Pre post process
wpreimes authored Dec 14, 2023
2 parents c6112e2 + 0c94d8c commit da44497
Showing 7 changed files with 188 additions and 55 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
@@ -19,17 +19,17 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        python-version: ['3.8', '3.9', '3.10']
+        python-version: ['3.8', '3.9', '3.10', '3.11']
         os: ["ubuntu-latest"]
         include:
           - os: "windows-latest"
-            python-version: '3.10'
+            python-version: '3.11'
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
        with:
          submodules: true
          fetch-depth: 0
-      - uses: conda-incubator/setup-miniconda@v2
+      - uses: conda-incubator/setup-miniconda@v3
        with:
          miniconda-version: "latest"
          auto-update-conda: true
8 changes: 8 additions & 0 deletions CHANGELOG.rst
@@ -7,6 +7,14 @@ Unreleased changes in master branch

 -

+Version 0.10
+============
+
+- Ts2Img module was rebuilt. Allows conversion of time series with NN lookup.
+- Added example notebook for converting ASCAT time series into regularly gridded images.
+- Added a simple parallelization framework, with logging and error handling.
+- Added the option to pass custom pre- and post-processing functions to ts2img.
+
 Version 0.9
 ===========

6 changes: 4 additions & 2 deletions environment.yml
@@ -16,16 +16,18 @@ dependencies:
   - xarray
   - pip
   # optional, for docs and testing
-  - nb_conda
+  #- nb_conda
   - matplotlib
   - ipykernel
   - pip:
     - pygeogrids
     - pynetcf>=0.5.0
     - more_itertools
-    - sphinx_rtd_theme
     - smecv_grid
     - tqdm
+    # Optional, for documentation and testing
+    - nbconvert
+    - sphinx_rtd_theme
     - yapf
     - pytest
     - pytest-cov
85 changes: 50 additions & 35 deletions src/repurpose/process.py
@@ -52,9 +52,11 @@ def parallel_process_async(
         n_proc=1,
         show_progress_bars=True,
         ignore_errors=False,
+        activate_logging=True,
         log_path=None,
         loglevel="WARNING",
         verbose=False,
+        progress_bar_label="Processed"
 ):
     """
     Applies the passed function to all elements of the passed iterables.
@@ -83,47 +85,58 @@ def parallel_process_async(
         this case the return values are kept in order.
     show_progress_bars: bool, optional (default: True)
         Show how many iterables were processed already.
+    ignore_errors: bool, optional (default: False)
+        If True, exceptions are caught and logged. If False, exceptions are
+        raised.
+    activate_logging: bool, optional (default: True)
+        If False, no logging is done at all (neither to file nor to stdout).
     log_path: str, optional (default: None)
         If provided, a log file is created in the passed directory.
     loglevel: str, optional (default: "WARNING")
         Log level to use for logging. Must be one of
         ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"].
     verbose: bool, optional (default: False)
         Print all logging messages to stdout, useful for debugging.
+    progress_bar_label: str, optional (default: "Processed")
+        Label to use for the progress bar.

     Returns
     -------
     results: list
         List of return values from each function call
     """
-    logger = logging.getLogger()
-    streamHandler = logging.StreamHandler(sys.stdout)
-    formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-    streamHandler.setFormatter(formatter)
-
-    if STATIC_KWARGS is None:
-        STATIC_KWARGS = dict()
-
-    if verbose:
-        logger.setLevel('DEBUG')
-        logger.addHandler(streamHandler)
-
-    if log_path is not None:
-        log_file = os.path.join(
-            log_path,
-            f"{FUNC.__name__}_{datetime.now().strftime('%Y%m%d%H%M')}.log")
+    if activate_logging:
+        logger = logging.getLogger()
+        streamHandler = logging.StreamHandler(sys.stdout)
+        formatter = logging.Formatter(
+            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        streamHandler.setFormatter(formatter)
+
+        if STATIC_KWARGS is None:
+            STATIC_KWARGS = dict()
+
+        if verbose:
+            logger.setLevel('DEBUG')
+            logger.addHandler(streamHandler)
+
+        if log_path is not None:
+            log_file = os.path.join(
+                log_path,
+                f"{FUNC.__name__}_{datetime.now().strftime('%Y%m%d%H%M')}.log")
+        else:
+            log_file = None
+
+        if log_file:
+            os.makedirs(os.path.dirname(log_file), exist_ok=True)
+            logging.basicConfig(
+                filename=log_file,
+                level=loglevel.upper(),
+                format="%(levelname)s %(asctime)s %(message)s",
+                datefmt="%Y-%m-%d %H:%M:%S",
+            )
     else:
-        log_file = None
-
-    if log_file:
-        os.makedirs(os.path.dirname(log_file), exist_ok=True)
-        logging.basicConfig(
-            filename=log_file,
-            level=loglevel.upper(),
-            format="%(levelname)s %(asctime)s %(message)s",
-            datefmt="%Y-%m-%d %H:%M:%S",
-        )
+        logger = None

     n = np.array([len(v) for k, v in ITER_KWARGS.items()])
     if len(n) == 0:
@@ -150,7 +163,7 @@ def parallel_process_async(
         process_kwargs.append(kws)

     if show_progress_bars:
-        pbar = tqdm(total=len(process_kwargs), desc=f"Processed")
+        pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
     else:
         pbar = None

@@ -163,7 +176,8 @@ def update(r) -> None:
         pbar.update()

     def error(e) -> None:
-        logging.error(e)
+        if logger is not None:
+            logging.error(e)
         if not ignore_errors:
             raise e
         if pbar is not None:
@@ -191,12 +205,13 @@ def error(e) -> None:
     if pbar is not None:
         pbar.close()

-    if verbose:
-        logger.handlers.clear()
+    if logger is not None:
+        if verbose:
+            logger.handlers.clear()

-    handlers = logger.handlers[:]
-    for handler in handlers:
-        logger.removeHandler(handler)
-        handler.close()
+        handlers = logger.handlers[:]
+        for handler in handlers:
+            logger.removeHandler(handler)
+            handler.close()

     return results
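
For orientation, here is a minimal usage sketch of `parallel_process_async` exercising the two keywords introduced in this PR (`activate_logging`, `progress_bar_label`). The worker function `scale` and its arguments are hypothetical; only the call shape follows the signature and docstring shown in the diff above:

```python
from repurpose.process import parallel_process_async

def scale(value, factor):
    # Hypothetical worker: called once per element of the ITER_KWARGS lists.
    return value * factor

results = parallel_process_async(
    scale,
    ITER_KWARGS={"value": [1, 2, 3]},  # one list entry per call
    STATIC_KWARGS={"factor": 2},       # passed unchanged to every call
    n_proc=1,                          # one process keeps results in input order
    activate_logging=False,            # new in this PR: skip logging entirely
    progress_bar_label="Scaled",       # new in this PR: custom progress bar text
)
print(results)  # [2, 4, 6]
```

Passing `activate_logging=False` avoids touching the root logger at all, which the restructured block above guards with `if activate_logging:` and `logger = None`.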
74 changes: 64 additions & 10 deletions src/repurpose/ts2img.py
@@ -21,8 +21,6 @@
 - add possibility to use resampling methods other than nearest neighbour
     - integrate repurpose.resample module
     - allows weighting functions etc.
-- add preprocessing and postprocessing keywords to change the input ts and
-  output stack before writing
 - similar to resample, use multiple neighbours when available for image pixel
 - further harmonisation with pynetcf interface
     - time ranges for images instead of time stamps
@@ -32,7 +30,9 @@ def _convert(converter: 'Ts2Img',
              writer: Regular3dimImageStack,
              img_gpis: np.ndarray,
              lons: np.ndarray,
-             lats: np.ndarray) -> xr.Dataset:
+             lats: np.ndarray,
+             preprocess_func=None,
+             preprocess_kwargs=None) -> xr.Dataset:
     """
     Wrapper to allow parallelization of the conversion process.
     This is kept outside the Ts2Img class for parallelization.
@@ -41,6 +41,9 @@ def _convert(converter: 'Ts2Img',
         ts = converter._read_nn(lon, lat)
         if ts is None:
             continue
+        if preprocess_func is not None:
+            preprocess_kwargs = preprocess_kwargs or {}
+            ts = preprocess_func(ts, **preprocess_kwargs)
         if np.any(np.isin(ts.columns, Ts2Img._protected_vars)):
             raise ValueError(
                 f"Time series contains protected variables. "
@@ -199,7 +202,8 @@ def _read_nn(self, lon: float, lat: float) -> Union[pd.DataFrame, None]:
         ts = ts.rename(columns=self.variables)[self.variables.values()]
         return ts

-    def _calc_chunk(self, timestamps, log_path=None, n_proc=1):
+    def _calc_chunk(self, timestamps, preprocess_func=None, preprocess_kwargs=None,
+                    log_path=None, n_proc=1):
         """
         Create image stack from time series for the passed timestamps.
         See: self.calc
@@ -209,7 +213,11 @@ def _calc_chunk(self, timestamps, log_path=None, n_proc=1):
                      f"{timestamps[-1]}")

         # Transfer time series to images, parallel for cells
-        STATIC_KWARGS = {'converter': self}
+        STATIC_KWARGS = {
+            'converter': self,
+            'preprocess_func': preprocess_func,
+            'preprocess_kwargs': preprocess_kwargs,
+        }
         ITER_KWARGS = {'writer': [], 'img_gpis': [], 'lons': [], 'lats': []}

         for cell in np.unique(self.img_grid.activearrcell):
@@ -232,10 +240,12 @@ def _calc_chunk(self, timestamps, log_path=None, n_proc=1):
                                 lon=stack['lon']))
         return stack

-    def calc(self, path_out, format_out='slice',
-             fn_template="{datetime}.nc", drop_empty=False, encoding=None,
-             zlib=True, glob_attrs=None, var_attrs=None,
-             var_fillvalues=None, var_dtypes=None, img_buffer=100, n_proc=1):
+    def calc(self, path_out, format_out='slice', preprocess=None,
+             preprocess_kwargs=None, postprocess=None, postprocess_kwargs=None,
+             fn_template="{datetime}.nc",
+             drop_empty=False, encoding=None, zlib=True, glob_attrs=None,
+             var_attrs=None, var_fillvalues=None, var_dtypes=None,
+             img_buffer=100, n_proc=1):
         """
         Perform conversion of all time series to images. This will first split
         timestamps into processing chunks (img_buffer) and then - for each
@@ -253,6 +263,43 @@ def calc(self, path_out, format_out='slice',
         - stack: write all time steps into one file. In this case if there
           is a {datetime} placeholder in the fn_template, then the time
           range is inserted.
+        preprocess: callable, optional (default: None)
+            Function that is applied to each time series before converting it.
+            The first argument is the data frame that the reader returns.
+            Additional keyword arguments can be passed via `preprocess_kwargs`.
+            The function must return a data frame of the same form as the
+            input data, i.e. with a datetime index and at least one column of
+            data.
+            Note: As an alternative to a preprocessing function, consider
+            applying an adapter to the reader class. Adapters also perform
+            preprocessing, see `pytesmo.validation_framework.adapters`.
+            A simple example of a preprocessing function to compute the sum
+            of two variables:
+            ```
+            def preprocess_add(df: pd.DataFrame, **preprocess_kwargs) \
+                    -> pd.DataFrame:
+                df['var3'] = df['var1'] + df['var2']
+                return df
+            ```
+        preprocess_kwargs: dict, optional (default: None)
+            Keyword arguments for the preprocess function. If None are given,
+            then the preprocessing function is called with only the input
+            data frame and no additional arguments (see example above).
+        postprocess: callable, optional (default: None)
+            Function that is applied to the image stack after loading the data
+            and before writing it to disk. The function must take an xarray
+            Dataset as the first argument and return an xarray Dataset of the
+            same form as the input data.
+            A simple example of a postprocessing function to add a new
+            variable from the sum of two existing variables:
+            ```
+            def postprocess_add(stack: xr.Dataset, **postprocess_kwargs) \
+                    -> xr.Dataset:
+                stack = stack.assign(var3=lambda x: x['var0'] + x['var2'])
+                return stack
+            ```
+        postprocess_kwargs: dict, optional (default: None)
+            Keyword arguments for the postprocess function. If None are given,
+            then the postprocess function is called with only the input
+            image stack and no additional arguments (see example above).
         fn_template: str, optional (default: "{datetime}.nc")
             Template for the output image file names.
             If format_out is 'slice', then a placeholder {datetime} must be
@@ -298,6 +345,7 @@ def calc(self, path_out, format_out='slice',
         img_buffer: int, optional (default: 100)
             Size of the stack before writing to disk. Larger stacks need
             more memory but will lead to faster conversion.
+            Passing -1 means that the whole stack is loaded into memory at once.
         n_proc: int, optional (default: 1)
             Number of processes to use for parallel processing. We parallelize
             by 5 deg. grid cell.
@@ -315,7 +363,9 @@ def calc(self, path_out, format_out='slice',
         dt_index_chunks = list(idx_chunks(self.timestamps, int(img_buffer)))

         for timestamps in dt_index_chunks:
-            self.stack = self._calc_chunk(timestamps, log_path, n_proc)
+            self.stack = self._calc_chunk(timestamps,
+                                          preprocess, preprocess_kwargs,
+                                          log_path, n_proc)

             if drop_empty:
                 vars = [var for var in self.stack.data_vars if var not in
@@ -331,6 +381,10 @@ def calc(self, path_out, format_out='slice',

             self.stack = self.stack.drop_isel(time=idx_empty)

+            if postprocess is not None:
+                postprocess_kwargs = postprocess_kwargs or {}
+                self.stack = postprocess(self.stack, **postprocess_kwargs)
+
             if var_fillvalues is not None:
                 for var, fillvalue in var_fillvalues.items():
                     self.stack[var].values = np.nan_to_num(
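
Putting the new keywords together, here is a hedged sketch of a `calc` call. It assumes an already constructed `Ts2Img` instance named `converter` (reader, grid and timestamps omitted), and `add_sum` / `mask_negative` are hypothetical user functions matching the contracts described in the docstring above:

```python
import pandas as pd
import xarray as xr

def add_sum(df: pd.DataFrame, a: str = "var1", b: str = "var2") -> pd.DataFrame:
    # Preprocess: applied to each time series before it is converted to images.
    df["var3"] = df[a] + df[b]
    return df

def mask_negative(stack: xr.Dataset, var: str = "var3") -> xr.Dataset:
    # Postprocess: applied to each image stack before it is written to disk.
    return stack.where(stack[var] >= 0)

converter.calc(
    path_out="/tmp/images",            # hypothetical output directory
    format_out="slice",                # one netCDF file per time stamp
    fn_template="{datetime}.nc",
    preprocess=add_sum,
    preprocess_kwargs={"a": "var1", "b": "var2"},
    postprocess=mask_negative,
    postprocess_kwargs={"var": "var3"},
    img_buffer=100,
    n_proc=2,
)
```

Note that `preprocess` runs per grid point inside `_convert` (its output must not use protected variable names), while `postprocess` runs once per image-stack chunk in `calc`, after empty time steps are dropped.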
2 changes: 1 addition & 1 deletion tests/test_image.py
@@ -148,7 +148,7 @@ def test_add_3d_data_via_ts(self):
             warnings.simplefilter("ignore",
                                   category=pd.errors.PerformanceWarning)
             ts = pd.DataFrame(
-                index=self.img_timestamps + self.timeoffsets,
+                index=pd.DatetimeIndex(self.img_timestamps + self.timeoffsets),
                 data={'var1': np.arange(1.1, 6.1).astype('float32'),
                       'var2': np.arange(11, 16).astype('int8')}
            )
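
As an aside on this test fix (our reading, not stated in the diff): wrapping the combined timestamps and offsets in `pd.DatetimeIndex` pins the index to datetime64 instead of leaving the dtype to `pd.DataFrame` inference. A small illustration with made-up values:

```python
import numpy as np
import pandas as pd

timestamps = np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[ns]")
offsets = np.array([1, 2], dtype="timedelta64[h]")

# Elementwise addition gives datetime64 values; the explicit DatetimeIndex
# makes the DataFrame index dtype explicit rather than inferred.
idx = pd.DatetimeIndex(timestamps + offsets)
df = pd.DataFrame(index=idx, data={"var1": np.array([1.1, 2.2], dtype="float32")})
print(df.index.dtype)  # datetime64[ns]
```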
Loading

0 comments on commit da44497

Please sign in to comment.