Skip to content

Commit

Permalink
Merge pull request #69 from oloapinivad/fix/atos
Browse files Browse the repository at this point in the history
various bug fix found on atos and improve support for globo
  • Loading branch information
oloapinivad authored Mar 15, 2023
2 parents 9ad612b + f7a2581 commit 9419cd5
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 379 deletions.
2 changes: 1 addition & 1 deletion config.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ global:
rsns, rlns, hfss, hfls, net_sfc_nosn, net_sfc,
toamsfc_nosn, toamsfc]
oce_vars: [tos, sos, zos, wfo]
ice_vars: [siconc_north, siconc_south]
ice_vars: [siconc]
tab_vars: [rsnt, rlnt, net_toa, rsns, rlns, hfss, hfls, swcf, lwcf, net_sfc, toamsfc,
tas, clt, cll, clm, clh, pr, pme, psl]
regions: ['Global', 'North Midlat', 'Tropical', 'South Midlat']
Expand Down
125 changes: 64 additions & 61 deletions ecmean/global_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import yaml
from ecmean.libs.diagnostic import Diagnostic
from ecmean.libs.general import weight_split, write_tuning_table, get_domain, \
numeric_loglevel, check_time_axis, dict_to_dataframe, init_mydict
numeric_loglevel, check_time_axis, dict_to_dataframe, init_mydict, check_interface
from ecmean.libs.files import var_is_there, get_inifiles, load_yaml, make_input_filename
from ecmean.libs.formula import formula_wrapper
from ecmean.libs.masks import masked_meansum, select_region
Expand Down Expand Up @@ -54,85 +54,88 @@ def gm_worker(util, ref, face, diag, varmean, vartrend, varlist):

for var in varlist:

# get domain
domain = get_domain(var, face)

# compute weights
weights = getattr(util, domain + 'area')
domain_mask = getattr(util, domain + 'mask')

# get input files/fielf
infile = make_input_filename(var, face, diag)

# check if variables are available
isavail, varunit = var_is_there(infile, var, face)

# create empty nested dictionaries
result = init_mydict(diag.seasons, diag.regions)
trend = init_mydict(diag.seasons, diag.regions)

if isavail:
# check if the variable is in the interface file
if check_interface(var, face):

# perform the unit conversion extracting offset and factor
units_handler = UnitsHandler(var, org_units=varunit, clim=ref, face=face)
offset, factor = units_handler.offset, units_handler.factor
# get domain
domain = get_domain(var, face)

# load the object
if not isinstance(infile, (xr.DataArray, xr.Dataset)):
xfield = xr.open_mfdataset(infile, preprocess=xr_preproc, chunks={'time': 12})
else:
xfield = infile
# compute weights
weights = getattr(util, domain + 'area')
domain_mask = getattr(util, domain + 'mask')

# in case of big files with multi year, be sure of having opened the right records
xfield = xfield.sel(time=xfield.time.dt.year.isin(diag.years_joined))
# get input files/fielf
infile = make_input_filename(var, face, diag)

# check time axis
check_time_axis(xfield.time, diag.years_joined)
# check if variables are available
isavail, varunit = var_is_there(infile, var, face)

# get the data-array field for the required var
cfield = formula_wrapper(var, face, xfield).compute()
if isavail:

for season in diag.seasons:
# perform the unit conversion extracting offset and factor
units_handler = UnitsHandler(var, org_units=varunit, clim=ref, face=face)
offset, factor = units_handler.offset, units_handler.factor

# copy of the full field
tfield = cfield.copy(deep=True)
# load the object
if not isinstance(infile, (xr.DataArray, xr.Dataset)):
xfield = xr.open_mfdataset(infile, preprocess=xr_preproc, chunks={'time': 12})
else:
xfield = infile

if season != 'ALL':
tfield = tfield.sel(time=cfield.time.dt.season.isin(season))
# in case of big files with multi year, be sure of having opened the right records
xfield = xfield.sel(time=xfield.time.dt.year.isin(diag.years_joined))

if diag.ftrend:
# this does not consider continuous seasons for DJF, but JF+D
tfield = tfield.groupby('time.year').mean('time')
else:
tfield = tfield.mean(dim='time')
# check time axis
check_time_axis(xfield.time, diag.years_joined)

for region in diag.regions:
# get the data-array field for the required var
cfield = formula_wrapper(var, face, xfield).compute()

slicefield = select_region(tfield, region)
sliceweights = select_region(weights, region)
if isinstance(domain_mask, xr.DataArray):
slicemask = select_region(domain_mask, region)
else:
slicemask = 0.
for season in diag.seasons:

# final operation on the field
a = masked_meansum(
xfield=slicefield, weights=sliceweights, mask=slicemask,
operation=ref[var].get('operation', 'mean'),
mask_type=ref[var].get('mask', 'global'),
domain=domain)
# copy of the full field
tfield = cfield.copy(deep=True)

try:
x = a.compute()
except BaseException:
x = a
result[season][region] = float((np.nanmean(x) + offset) * factor)
if season != 'ALL':
tfield = tfield.sel(time=cfield.time.dt.season.isin(season))

if diag.ftrend:
if (len(x) == len(diag.years_joined)):
trend[season][region] = np.polyfit(diag.years_joined, x, 1)[0]
if diag.fverb and season == 'ALL' and region == 'Global':
print('Average', var, season, region, result[season][region])
# this does not consider continuous seasons for DJF, but JF+D
tfield = tfield.groupby('time.year').mean('time')
else:
tfield = tfield.mean(dim='time')

for region in diag.regions:

slicefield = select_region(tfield, region)
sliceweights = select_region(weights, region)
if isinstance(domain_mask, xr.DataArray):
slicemask = select_region(domain_mask, region)
else:
slicemask = 0.

# final operation on the field
a = masked_meansum(
xfield=slicefield, weights=sliceweights, mask=slicemask,
operation=ref[var].get('operation', 'mean'),
mask_type=ref[var].get('mask', 'global'),
domain=domain)

try:
x = a.compute()
except BaseException:
x = a
result[season][region] = float((np.nanmean(x) + offset) * factor)

if diag.ftrend:
if (len(x) == len(diag.years_joined)):
trend[season][region] = np.polyfit(diag.years_joined, x, 1)[0]
if diag.fverb and season == 'ALL' and region == 'Global':
print('Average', var, season, region, result[season][region])

# nested dictionary, to be redifend as a dict to remove lambdas
varmean[var] = result
Expand Down
2 changes: 1 addition & 1 deletion ecmean/interfaces/interface_EC-Earth4.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ component:
nemo:
maskfile: ''
gridfile: ''
areafile: 'nemo-initial-state.nc'
areafile: 'domain_cfg.nc'

# Definition of individual variables
variables:
Expand Down
14 changes: 11 additions & 3 deletions ecmean/interfaces/interface_GLOBO.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ model:
basedir: '{expname}'
component:
atm: globo
oce: fesom
oce: nemo

# Definition of filetypes (specific for each component)
# Each variable belongs to a filetype
Expand All @@ -39,9 +39,9 @@ filetype:
component:
globo:
maskfile: 'GLOBONE_BCs.nc'
gridfile: 'GLOBONE_atm_6hrs_{year1}.nc'
gridfile: ''
areafile: ''
fesom:
nemo:
maskfile: ''
gridfile: ''
areafile: ''
Expand Down Expand Up @@ -145,3 +145,11 @@ variables:
varname: 'Precip. minus evap. (ocean)'
derived: 'pr+evspsbl'
filetype: atm_6hrs
siconc:
varname: 'Sea-Ice Concentration'
derived: sic
filetype: atm_6hrs
tos:
varname: 'SST'
filetype: atm_6hrs

12 changes: 11 additions & 1 deletion ecmean/libs/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def check_time_axis(xtime, years):
"""Check if we have 12 months per year and if the required years
have been found in the NetCDF files. """

unique, counts = np.unique(xtime.dt.month, return_counts=True)
#unique, counts = np.unique(xtime.dt.month, return_counts=True)
unique, counts = np.unique(xtime.time.resample(time='1M').mean(), return_counts=True)
if len(unique) != 12 or not all(counts == counts[0]):
logging.warning('Check your data: some months might be missing...')

Expand Down Expand Up @@ -103,6 +104,15 @@ def weight_split(a, n):

return olists

def check_interface(var, face):
"""Check if a var is defined in the interface file"""

if var in face['variables']:
return True
else:
logging.warning(f'{var} is not defined in the interface file, skipping it!')
return False


def get_domain(var, face):
"""Given a variable var extract its domain (oce or atm) from the interface.
Expand Down
41 changes: 6 additions & 35 deletions ecmean/libs/ncfixers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def xr_preproc(ds):
"values": "cell",
}

# safe check for NEMO output in domain_cfg.nc
if 'nav_lon' in ds.data_vars and 'nav_lat' in ds.data_vars:
ds = ds.set_coords(['nav_lon', 'nav_lat'])

# compact renaming
for old_name in [name for name in rename_dict if name in ds.dims or name in ds.coords]:
ds = ds.rename({old_name: rename_dict[old_name]})
Expand All @@ -31,49 +35,16 @@ def xr_preproc(ds):
for h in ['lon', 'lat']:
for f in ['', 'grid_T']:
g = 'nav_' + h + '_' + f
if g in list(ds.coords):
if g in ds.coords:
ds = ds.rename({g: h})

# fix for NEMO eORCA grid (x_grid_T, etc.)
for h in ['x', 'y']:
for f in ['grid_T']:
g = h + '_' + f
if g in list(ds.dims):
if g in ds.dims:
ds = ds.rename({g: h})

# # print(ds)
# if 'time_counter' in list(ds.dims):
# ds = ds.rename({"time_counter": "time"})

# if 'time_counter' in list(ds.coords):
# ds = ds.rename({"time_counter": "time"})

# if 'pressure_levels' in list(ds.coords):
# ds = ds.rename({"pressure_levels": "plev"})

# if 'plevel' in list(ds.dims):
# ds = ds.rename({"plevel": "plev"})

# if 'longitude' in list(ds.dims):
# ds = ds.rename({"longitude": "lon"})

# if 'latitude' in list(ds.dims):
# ds = ds.rename({"latitude": "lat"})

# if 'longitude' in list(ds.coords):
# ds = ds.rename({"longitude": "lon"})

# if 'nav_lon' in list(ds.coords):
# ds = ds.rename({"nav_lon": "lon"})

# if 'latitude' in list(ds.coords):
# ds = ds.rename({"latitude": "lat"})

# if 'nav_lat' in list(ds.coords):
# ds = ds.rename({"nav_lat": "lat"})

# if 'values' in list(ds.dims):
# ds = ds.rename({"values": "cell"})

return ds

Expand Down
4 changes: 2 additions & 2 deletions ecmean/libs/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def heatmap_comparison_pi(relative_table, diag, filemap):
chart = sns.heatmap(myfield, norm=divnorm, cmap=pal,
cbar_kws={"ticks": tictoc, 'label': title},
ax=axs, annot=True, linewidth=0.5, fmt='.2f',
annot_kws={'fontsize': 11, 'fontweight': 'bold'})
annot_kws={'fontsize': 14, 'fontweight': 'bold'})
chart = chart.set_facecolor('whitesmoke')
axs.set_title(f'{title} {diag.modelname} {diag.expname} {diag.year1} {diag.year2}', fontsize=25)
axs.vlines(list(range(sss, tot + sss, sss)), ymin=-1, ymax=len(myfield.index), colors='k')
Expand Down Expand Up @@ -83,7 +83,7 @@ def heatmap_comparison_gm(data_table, mean_table, std_table, diag, filemap):
sss = (len(set([tup[1] for tup in clean.columns])))

chart = sns.heatmap(clean, annot=data_table[mask], vmin=-thr - 0.5, vmax=thr + 0.5, center=0,
annot_kws={'va': 'bottom', 'fontsize': 12},
annot_kws={'va': 'bottom', 'fontsize': 14},
cbar_kws={'ticks': tictoc,
'label': 'Model Bias \n (standard deviation of interannual variability from observations)'},
fmt='.2f', cmap=pal)
Expand Down
Loading

0 comments on commit 9419cd5

Please sign in to comment.