
CDAT Migration Phase 2: Refactor arm_diags set #842

Open · wants to merge 34 commits into base: cdat-migration-fy24

Conversation

@chengzhuzhang (Contributor) commented Aug 26, 2024

Description

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • My changes generate no new warnings
  • Any dependent changes have been merged and published in downstream modules

If applicable:

  • New and existing unit tests pass with my changes (locally and CI/CD build)
  • I have added tests that prove my fix is effective or that my feature works
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have noted that this is a breaking change for a major release (fix or feature that would cause existing functionality to not work as expected)

@chengzhuzhang (Contributor, Author):

Updates:
Refactoring of all sub-sets (annual_cycle, diurnal_cycle(_zt), convection_onset, aerosol_activation) is complete.

Todo:
  • Fix mypy issues
  • Investigate a potential performance bottleneck from deriving variables using an xarray dataset as input

@chengzhuzhang marked this pull request as ready for review September 9, 2024 17:25
@tomvothecoder (Collaborator) commented Sep 9, 2024

> Potential performance bottleneck from deriving variables using xarray dataset as input

Thank you for identifying the two formula computations as potential bottlenecks. I confirmed that these computations with Xarray are indeed much slower than with CDAT.

Solution

I ran a performance benchmark and found the solution to the slowness: we need to call .load(scheduler="sync") in _get_dataset_with_source_vars() to speed up the computation. I pushed commit ef44fc6 (#842) with this fix.

Benchmark Results

The first runtime is with the current code and the second is with .load(). I also ran e3sm_diags with commit ef44fc6 (#842) and confirmed a significant runtime improvement, similar to the benchmarks below. Performance is now on par with CDAT.

"""
Results
----------
1. Elapsed time (Xarray non-chunked): 6.540755605790764 seconds
2. Elapsed time (Xarray non-chunked with .load()): 0.17097265785560012 seconds
3. Elapsed time (Xarray chunked): 0.1452920027077198 seconds
4. Elapsed time (numpy .values): 6.418793010059744 seconds
5. Elapsed time (numpy .data): 7.334999438840896 seconds
"""
Elapsed time (CDAT main branch, single runtime): 0.12261438369750977 seconds
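
For context, a minimal sketch of how such a timing comparison could be set up (illustrative only; the file glob, variable name, and formula are placeholder assumptions, not the actual benchmark script):

import time

import xarray as xr

# Hypothetical input files; open_mfdataset returns lazy, Dask-backed arrays.
ds_lazy = xr.open_mfdataset("PRECT_*.nc")

def time_formula(ds: xr.Dataset, label: str) -> None:
    # Stand-in for a derived-variable formula (e.g., a unit conversion).
    start = time.perf_counter()
    result = ds["PRECT"] * 3600.0 * 24.0
    _ = result.values  # force computation of the (possibly lazy) result
    print(f"Elapsed time ({label}): {time.perf_counter() - start} seconds")

time_formula(ds_lazy, "lazy, no .load()")
time_formula(ds_lazy.load(scheduler="sync"), "after .load(scheduler='sync')")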

@chengzhuzhang (Contributor, Author):

> The first runtime is with the current code and the second is with .load(). I also ran e3sm_diags with commit ef44fc6 (#842)

@tomvothecoder thank you for performing the timing test. Interestingly, when I tested your commit by running a configuration that covers each subset, it took 24 minutes. Without .load(), the total time is about 3 minutes, which I consider at least on par with the CDAT code. In that case, maybe we should drop the .load() change?

@tomvothecoder (Collaborator):

> @tomvothecoder thank you for performing the timing test. Interestingly, when I tested your commit by running a configuration that covers each subset, it took 24 minutes. Without .load(), the total time is about 3 minutes, which I consider at least on par with the CDAT code. In that case, maybe we should drop the .load() change?

I would think that loading the derived-variables dataset into memory shouldn't slow down performance unless the datasets are extremely large (in which case we should use Dask chunking).

I only benchmarked performance for the formula computations. I will benchmark a complete run to verify your findings and determine if we should revert the commit or not.

Side note:

It could be that the logic I implemented already stores the dataset in memory, since it merges multiple xr.Dataset objects opened via open_dataset() (which uses numpy arrays) instead of using open_mfdataset() (which uses Dask arrays). If that is the case, though, I don't see how .load() would improve the speed of the formula computations.

for var in vars_to_get:
    ds = self._get_time_series_dataset_obj(var)
    datasets.append(ds)

ds = xr.merge(datasets)
ds = squeeze_time_dim(ds)
ds.load(scheduler="sync")

return ds
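
One way to check whether the merged dataset is already numpy-backed (in which case .load() is not deferring any computation) or still Dask-backed is sketched below; this is an illustrative helper, not part of the e3sm_diags codebase:

import dask.array as da
import xarray as xr

def is_dask_backed(ds: xr.Dataset) -> bool:
    """Return True if any data variable in the dataset is backed by a Dask array."""
    return any(isinstance(var.data, da.Array) for var in ds.data_vars.values())

# Example: if this prints False, the merged dataset is already in memory
# (numpy-backed), so .load() would not be speeding up the formula computations.
# print(is_dask_backed(ds))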

@tomvothecoder (Collaborator):

RE: #842 (comment)
I found that adding .load() in _get_dataset_with_source_vars() adds 2-4 minutes of runtime to a complete arm_diags run.

I reverted this change and will now address the pre-commit issues.

# Commit: 58361c49-4b1b-11ec-9b3b-9c5c8e2f5e4e (no .load())
# run_set function took 281.78 seconds to complete.
# run_set function took 332.79 seconds to complete.
# Commit: ef44fc6ffd538a5e257b097b99f7a1a79b79bc3b (with .load())
# run_set function took 472.81 seconds to complete.
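
The "run_set function took ... seconds" lines come from wall-clock timing of the driver; a minimal sketch of such a timing decorator is below (illustrative only, not necessarily the instrumentation used to produce the numbers above):

import functools
import time

def timed(func):
    """Print how long the wrapped function takes, mirroring the log lines above."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} function took {elapsed:.2f} seconds to complete.")
        return result
    return wrapper

@timed
def run_set():
    time.sleep(0.1)  # placeholder for the actual arm_diags run

run_set()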

@tomvothecoder (Collaborator) commented Sep 13, 2024

Hey @chengzhuzhang, I fixed the pre-commit issues in 2c248fb (#842). I also performed some initial code cleanup since I would probably be doing that later anyway. Refer to the commit message and my review comments for more information.

I re-ran all sets and they completed successfully with these changes, although I noticed some of the diagnostic sets weren't done yet.

@tomvothecoder (Collaborator) left a comment:

Initial review comment and some questions about the FIXME comments.

(Inline review comments on e3sm_diags/driver/arm_diags_driver.py)
@chengzhuzhang (Contributor, Author):

Thank you @tomvothecoder. I will do another pass to see if there are other sets that need to be refactored. Thanks a lot for fixing and cleaning up this branch.

@chengzhuzhang (Contributor, Author):

@tomvothecoder I ran the .png regression tests. All figures are produced and the results are as expected, as noted in the notebook.
Performance-wise, the refactored code is similar to the original: wall time is ~10 minutes with 4 workers for the full arm_diags set.

@tomvothecoder (Collaborator):

Hey @chengzhuzhang, I noticed you have several other PRs you're working on right now. I'm happy to help finish up refactoring the last diag function in this PR. Let me know.

@chengzhuzhang (Contributor, Author):

> Hey @chengzhuzhang, I noticed you have several other PRs you're working on right now. I'm happy to help finish up refactoring the last diag function in this PR. Let me know.

@tomvothecoder it would be great if you could help finish this last diag so that I'm not holding back the progress to merge!

@tomvothecoder (Collaborator):

> @tomvothecoder it would be great if you could help finish this last diag so that I'm not holding back the progress to merge!

I found that no configuration actually runs "annual_cycle_aerosol" (arm_diags_model_vs_obs.cfg, arm_diags_model_vs_model.cfg). Unless we expect to run this diagnostic in the future, I think we can delete _run_diag_annual_cycle_aerosol instead of refactoring it.

@chengzhuzhang (Contributor, Author):

@tomvothecoder good catch. I think for now we can just create an issue to log this problem, and we can add the code back at a later time.

@@ -152,6 +152,7 @@ def climo(dataset: xr.Dataset, var_key: str, freq: ClimoFreq):
# averaging.
dims = [dim for dim in dv.dims if dim != time_coords.name]
coords = {k: v for k, v in dv.coords.items() if k in dims}
climo = climo.squeeze(axis=0)
@tomvothecoder (Collaborator) commented Sep 27, 2024:

Fixes a bug when ncycle == 1 where the climo variable's time axis is not squeezed, which causes rebuilding the dv_climo DataArray to fail with the errors below (a minimal reproduction sketch follows the failure list):

FAILED tests/e3sm_diags/driver/utils/test_climo_xr.py::TestClimo::test_returns_annual_cycle_climatology - ValueError: different number of dimensions on data and dims: 3 vs 2
FAILED tests/e3sm_diags/driver/utils/test_climo_xr.py::TestClimo::test_returns_DJF_season_climatology - ValueError: different number of dimensions on data and dims: 3 vs 2
FAILED tests/e3sm_diags/driver/utils/test_climo_xr.py::TestClimo::test_returns_MAM_season_climatology - ValueError: different number of dimensions on data and dims: 3 vs 2
FAILED tests/e3sm_diags/driver/utils/test_climo_xr.py::TestClimo::test_returns_JJA_season_climatology - ValueError: different number of dimensions on data and dims: 3 vs 2
FAILED tests/e3sm_diags/driver/utils/test_climo_xr.py::TestClimo::test_returns_SON_season_climatology - ValueError: different number of dimensions on data and dims: 3 vs 2
FAILED tests/e3sm_diags/driver/utils/test_climo_xr.py::TestClimo::test_returns_jan_climatology - ValueError: different number of dimensions on data and dims: 3 vs 2
FAILED tests/e3sm_diags/driver/utils/test_climo_xr.py::TestClimo::test_returns_climatology_for_derived_variable - ValueError: different number of dimensions on data and dims: 3 vs 2
FAILED tests/e3sm_diags/driver/utils/test_dataset_xr.py::TestGetClimoDataset::test_returns_climo_dataset_using_climo_of_time_series_files - ValueError: different number of dimensions on data and dims: 3 vs 2
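
A minimal, self-contained reproduction of that failure mode (illustrative only; the shapes and dim names are placeholders, not the actual climo_xr.py code):

import numpy as np
import xarray as xr

# When ncycle == 1, the computed climatology keeps a leading time axis of
# length 1, so rebuilding a DataArray with only the remaining dims fails.
climo = np.ones((1, 2, 3))  # (time, lat, lon) with a single cycle
dims = ["lat", "lon"]

try:
    xr.DataArray(climo, dims=dims)
except ValueError as err:
    print(err)  # different number of dimensions on data and dims: 3 vs 2

# Squeezing the leading axis makes the data match the expected dims.
dv_climo = xr.DataArray(climo.squeeze(axis=0), dims=dims)
print(dv_climo.dims)  # ('lat', 'lon')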

@tomvothecoder (Collaborator) commented Sep 27, 2024

  • Fix the failing integration test
  • Address the remaining FIXME and TODO items in arm_diags_driver.py

Regression test results

The .png regression tests show all plots are identical except the following, which have missing data (white spaces) compared to main:

  • The CLOUD test-variable plots below have white spaces on the dev branch. This is not a concern: the dev branch is doing the right thing by masking closer to the surface at the 1000 mb pressure level.
Comparing:
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/arm_diags/armdiags-CLOUD-ANNUALCYCLE-nsac1-test.png
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags/armdiags-CLOUD-ANNUALCYCLE-nsac1-test.png
  * Difference path: /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags_diff/armdiags-CLOUD-ANNUALCYCLE-nsac1-test.png
Comparing:
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/arm_diags/armdiags-CLOUD-ANNUALCYCLE-sgpc1-test.png
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags/armdiags-CLOUD-ANNUALCYCLE-sgpc1-test.png
  * Difference path: /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags_diff/armdiags-CLOUD-ANNUALCYCLE-sgpc1-test.png
Comparing:
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/arm_diags/armdiags-CLOUD-ANNUALCYCLE-twpc1-test.png
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags/armdiags-CLOUD-ANNUALCYCLE-twpc1-test.png
  * Difference path: /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags_diff/armdiags-CLOUD-ANNUALCYCLE-twpc1-test.png
Comparing:
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/arm_diags/armdiags-CLOUD-ANNUALCYCLE-twpc2-test.png
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags/armdiags-CLOUD-ANNUALCYCLE-twpc2-test.png
  * Difference path: /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags_diff/armdiags-CLOUD-ANNUALCYCLE-twpc2-test.png
Comparing:
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/arm_diags/armdiags-CLOUD-ANNUALCYCLE-twpc3-test.png
  * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags/armdiags-CLOUD-ANNUALCYCLE-twpc3-test.png
  * Difference path: /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/667-arm_diags-final/arm_diags_diff/armdiags-CLOUD-ANNUALCYCLE-twpc3-test.png
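
For reference, a pixel-level comparison like the one above could be sketched as follows (illustrative only; this is not the actual regression-test tooling, and the paths are placeholders):

import numpy as np
from PIL import Image, ImageChops

def diff_pngs(path_a: str, path_b: str, diff_path: str) -> float:
    """Save a difference image and return the fraction of pixels that differ."""
    img_a = Image.open(path_a).convert("RGB")
    img_b = Image.open(path_b).convert("RGB")
    diff = ImageChops.difference(img_a, img_b)
    diff.save(diff_path)
    arr = np.asarray(diff)
    return float((arr.sum(axis=-1) > 0).mean())

# Example with placeholder paths:
# frac = diff_pngs("main/armdiags-CLOUD-ANNUALCYCLE-nsac1-test.png",
#                  "dev/armdiags-CLOUD-ANNUALCYCLE-nsac1-test.png",
#                  "diff/armdiags-CLOUD-ANNUALCYCLE-nsac1-test.png")
# print(f"{frac:.2%} of pixels differ")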

@tomvothecoder (Collaborator) left a comment:

Remaining TODO and FIXME comments in arm_diags_driver.py

(Inline review comments on e3sm_diags/driver/arm_diags_driver.py)
@tomvothecoder changed the title from "CDAT Migration Phase 2: Refactor arm_diags set (try 2)" to "CDAT Migration Phase 2: Refactor arm_diags set" on Sep 30, 2024
chengzhuzhang and others added 28 commits October 1, 2024 10:49
- Refactor `arm_diags_driver.py` - rename functions, reorder functions, rename variables for clarity, replace `.format` with f-strings, update docstrings, add `_select_point()`
- Refactor `climo_xr.py` - extract `_get_cycle_for_freq()`, remove commented out code
- Add `time_interval` to `ARMDiagsParameter`
- Rename functions to denote private and reorder based on call in `arm_diags_driver.py`
- Add typestrings and annotations
- Separate logically related blocks of code with comments
- Add `_save_plots()` function to replace repeated I/O across functions
- This behavior mimics the co flag found in the CDAT codebase
Comment on lines +1406 to +1445
def _exclude_sub_monthly_coord_spanning_year(
    self, ds_subset: xr.Dataset
) -> xr.Dataset:
    """Exclude the last time coordinate for sub-monthly data if it extends
    into the next year.

    Excluding end time coordinates that extend into the next year is
    necessary because downstream operations such as the annual cycle
    climatology should consist of data for full years for accurate
    calculations.

    For example, if the time slice is ("0001-01-01", "0002-01-01") and
    the last time coordinate is:

      * "0002-01-01" -> exclude
      * "0001-12-31" -> don't exclude

    Parameters
    ----------
    ds_subset : xr.Dataset
        The subsetted dataset.

    Returns
    -------
    xr.Dataset
        The dataset with the last time coordinate excluded if necessary.

    Notes
    -----
    This function replicates the CDAT cdms2 "co" slice flag (close, open).
    """
    time_dim = xc.get_dim_keys(ds_subset, axis="T")
    time_values = ds_subset[time_dim]
    last_time_year = time_values[-1].dt.year.item()
    second_last_time_year = time_values[-2].dt.year.item()

    if self.is_sub_monthly and last_time_year > second_last_time_year:
        ds_subset = ds_subset.isel(time=slice(0, -1))

    return ds_subset


Replicates the "co" slice flag for sub-monthly data
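
A toy illustration of that end-point behavior (illustrative only; the standalone helper and sample data below are not the actual e3sm_diags method):

import numpy as np
import pandas as pd
import xarray as xr

# Daily (sub-monthly) data whose last time step spills into the next year.
time = pd.date_range("2000-12-01", periods=32, freq="D")
ds = xr.Dataset(
    {"PRECT": ("time", np.arange(32, dtype=float))},
    coords={"time": time},
)

def exclude_last_if_next_year(ds: xr.Dataset) -> xr.Dataset:
    """Drop the final time step when it falls in the year after the previous one."""
    time_values = ds["time"]
    if time_values[-1].dt.year.item() > time_values[-2].dt.year.item():
        ds = ds.isel(time=slice(0, -1))
    return ds

ds_co = exclude_last_if_next_year(ds)
print(ds["time"].size, "->", ds_co["time"].size)  # 32 -> 31; the 2001-01-01 point is dropped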

Labels: cdat-migration-fy24 (CDAT Migration FY24 Task)
2 participants