Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix formatting of cice grid file #18

Merged
merged 7 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions esmgrids/cice_grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import netCDF4 as nc
from warnings import warn

from esmgrids.base_grid import BaseGrid

Expand Down Expand Up @@ -31,7 +32,11 @@ def fromfile(cls, h_grid_def, mask_file=None, description="CICE tripolar"):
area_t = f.variables["tarea"][:]
area_u = f.variables["uarea"][:]

angle_t = np.rad2deg(f.variables["angleT"][:])
try:
angle_t = np.rad2deg(f.variables["anglet"][:])
except KeyError:
angle_t = np.rad2deg(f.variables["angleT"][:])

angle_u = np.rad2deg(f.variables["angle"][:])

if "clon_t" in f.variables:
Expand Down Expand Up @@ -69,12 +74,12 @@ def _create_2d_nc_var(self, f, name):
return f.createVariable(
name,
"f8",
dimensions=("ny", "nx"),
dimensions=("nj", "ni"),
compression="zlib",
complevel=1,
)

def write(self, grid_filename, mask_filename, metadata=None):
def write(self, grid_filename, mask_filename, metadata=None, variant=None):
"""
Write out CICE grid to netcdf

Expand All @@ -88,14 +93,17 @@ def write(self, grid_filename, mask_filename, metadata=None):
Any global or variable metadata attributes to add to the files being written
"""

if variant is not None and variant != "cice5-auscom":
raise NotImplementedError(f"{variant} not recognised")

# Grid file
f = nc.Dataset(grid_filename, "w")

# Create dimensions.
f.createDimension("nx", self.num_lon_points)
# nx is the grid_longitude but doesn't have a value other than its index
f.createDimension("ny", self.num_lat_points)
# ny is the grid_latitude but doesn't have a value other than its index
f.createDimension("ni", self.num_lon_points)
# ni is the grid_longitude but doesn't have a value other than its index
f.createDimension("nj", self.num_lat_points)
# nj is the grid_latitude but doesn't have a value other than its index

# Make all CICE grid variables.
# names are based on https://cfconventions.org/Data/cf-standard-names/current/build/cf-standard-name-table.html
Expand Down Expand Up @@ -135,7 +143,12 @@ def write(self, grid_filename, mask_filename, metadata=None):
angle.standard_name = "angle_of_rotation_from_east_to_x"
angle.coordinates = "ulat ulon"
angle.grid_mapping = "crs"
angleT = self._create_2d_nc_var(f, "angleT")

if variant == "cice5-auscom":
angleT = self._create_2d_nc_var(f, "angleT")
elif variant is None:
angleT = self._create_2d_nc_var(f, "anglet")

angleT.units = "radians"
angleT.long_name = "Rotation angle of T cells."
angleT.standard_name = "angle_of_rotation_from_east_to_x"
Expand Down Expand Up @@ -185,8 +198,8 @@ def write(self, grid_filename, mask_filename, metadata=None):
# Mask file
f = nc.Dataset(mask_filename, "w")

f.createDimension("nx", self.num_lon_points)
f.createDimension("ny", self.num_lat_points)
f.createDimension("ni", self.num_lon_points)
f.createDimension("nj", self.num_lat_points)
mask = self._create_2d_nc_var(f, "kmt")
mask.grid_mapping = "crs"
mask.standard_name = "sea_binary_mask"
Expand Down
6 changes: 4 additions & 2 deletions esmgrids/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@ def cice_from_mom():
parser.add_argument("--ocean_mask", type=str, help="Input MOM ocean_mask.nc mask file")
parser.add_argument("--cice_grid", type=str, default="grid.nc", help="Output CICE grid file")
parser.add_argument("--cice_kmt", type=str, default="kmt.nc", help="Output CICE kmt file")
parser.add_argument("--cice_variant", type=str, default=None, help="Cice variant")

args = parser.parse_args()
ocean_hgrid = os.path.abspath(args.ocean_hgrid)
ocean_mask = os.path.abspath(args.ocean_mask)
cice_grid = os.path.abspath(args.cice_grid)
cice_kmt = os.path.abspath(args.cice_kmt)
cice_variant = args.cice_variant

version = safe_version()
runcmd = (
f"Created using https://github.com/COSIMA/esmgrids {version}: "
f"cice_from_mom --ocean_hgrid={ocean_hgrid} --ocean_mask={ocean_mask} "
f"--cice_grid={cice_grid} --cice_kmt={cice_kmt}"
f"--cice_grid={cice_grid} --cice_kmt={cice_kmt} --cice_variant={cice_variant}"
)
provenance_metadata = {
"inputfile": (
Expand All @@ -37,4 +39,4 @@ def cice_from_mom():

mom = MomGrid.fromfile(ocean_hgrid, mask_file=ocean_mask)
cice = CiceGrid.fromgrid(mom)
cice.write(cice_grid, cice_kmt, metadata=provenance_metadata)
cice.write(cice_grid, cice_kmt, metadata=provenance_metadata, variant=cice_variant)
165 changes: 101 additions & 64 deletions test/test_cice_grid.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import pytest
import xarray as xr
import warnings
from numpy.testing import assert_allclose
from numpy import deg2rad
from subprocess import run
from pathlib import Path

# from esmgrids.cli import cice_from_mom
from esmgrids.mom_grid import MomGrid
from esmgrids.cice_grid import CiceGrid

# create test grids at 4 degrees and 0.1 degrees
# 4 degress is the lowest tested in ocean_model_grid_generator
# going higher resolution than 0.1 has too much computational cost
_test_resolutions = [4, 0.1]

_variants = ["cice5-auscom", None]

# so that our fixtures are only create once in this pytest module, we need this special version of 'tmp_path'

# so that our fixtures are only created once in this pytest module, we need this special version of 'tmp_path'
@pytest.fixture(scope="module")
def tmp_path(tmp_path_factory: pytest.TempdirFactory) -> Path:
return tmp_path_factory.mktemp("temp")
Expand Down Expand Up @@ -53,39 +57,31 @@ def __init__(self, res, tmp_path):
class CiceGridFixture:
"""Make the CICE grid, using script under test"""

def __init__(self, mom_grid, tmp_path):
def __init__(self, mom_grid, tmp_path, variant):
self.path = str(tmp_path) + "/grid.nc"
self.kmt_path = str(tmp_path) + "/kmt.nc"
run(
[
"cice_from_mom",
"--ocean_hgrid",
mom_grid.path,
"--ocean_mask",
mom_grid.mask_path,
"--cice_grid",
self.path,
"--cice_kmt",
self.kmt_path,
]
)
self.ds = xr.open_dataset(self.path, decode_cf=False)
self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False)

run_cmd = [
"cice_from_mom",
"--ocean_hgrid",
mom_grid.path,
"--ocean_mask",
mom_grid.mask_path,
"--cice_grid",
self.path,
"--cice_kmt",
self.kmt_path,
]
if variant is not None:
run_cmd.append("--cice_variant")
run_cmd.append(variant)
run(run_cmd)

# pytest doesn't support class fixtures, so we need these two constructor funcs
@pytest.fixture(scope="module", params=_test_resolutions)
def mom_grid(request, tmp_path):
return MomGridFixture(request.param, tmp_path)


@pytest.fixture(scope="module")
def cice_grid(mom_grid, tmp_path):
return CiceGridFixture(mom_grid, tmp_path)
self.ds = xr.open_dataset(self.path, decode_cf=False)
self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False)


@pytest.fixture(scope="module")
def test_grid_ds(mom_grid):
def gen_grid_ds(mom_grid, variant):
# this generates the expected answers
# In simple terms the MOM supergrid has four cells for each model grid cell. The MOM supergrid includes all edges (left and right) but CICE only uses right/east edges. (e.g. For points/edges of first cell: 0,0 is SW corner, 1,1 is middle of cell, 2,2, is NE corner/edges)

Expand All @@ -105,7 +101,10 @@ def test_grid_ds(mom_grid):
test_grid["tlon"] = deg2rad(t_points.x)

test_grid["angle"] = deg2rad(u_points.angle_dx) # angle at u point
test_grid["angleT"] = deg2rad(t_points.angle_dx)
if variant == "cice5-auscom":
test_grid["angleT"] = deg2rad(t_points.angle_dx)
else: # cice6
test_grid["anglet"] = deg2rad(t_points.angle_dx)

# length of top (northern) edge of cells
test_grid["htn"] = ds.dx.isel(nyp=slice(2, None, 2)).coarsen(nx=2).sum() * 100
Expand All @@ -128,33 +127,57 @@ def test_grid_ds(mom_grid):
return test_grid


# pytest doesn't support class fixtures, so we need these two constructor funcs
@pytest.fixture(scope="module", params=_test_resolutions)
def mom_grid(request, tmp_path):
return MomGridFixture(request.param, tmp_path)


# the variant neews to be the same for both the cice_grid and the test_grid, so bundle them
@pytest.fixture(scope="module", params=_variants)
def grids(request, mom_grid, tmp_path):
return {"cice": CiceGridFixture(mom_grid, tmp_path, request.param), "test_ds": gen_grid_ds(mom_grid, request.param)}


# ----------------
# the tests in earnest:


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_cice_var_list(cice_grid, test_grid_ds):
def test_cice_var_list(grids):
# Test : Are there missing vars in cice_grid?
assert set(test_grid_ds.variables).difference(cice_grid.ds.variables) == set()
assert set(grids["test_ds"].variables).difference(grids["cice"].ds.variables) == set()


def test_cice_dims(grids):
# Test : Are the dim names consistent with cice history output?
assert set(grids["cice"].ds.dims) == set(
["ni", "nj"]
), "cice dimension names should be 'ni','nj' to be consistent with history output"
assert grids["cice"].ds.sizes["ni"] == len(grids["test_ds"].nx)
assert grids["cice"].ds.sizes["nj"] == len(grids["test_ds"].ny)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_cice_grid(cice_grid, test_grid_ds):
def test_cice_grid(grids):
# Test : Is the data the same as the test_grid
for jVar in test_grid_ds.variables:
assert_allclose(cice_grid.ds[jVar], test_grid_ds[jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch")
for jVar in grids["test_ds"].variables:
assert_allclose(
grids["cice"].ds[jVar], grids["test_ds"][jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch"
)


def test_cice_kmt(mom_grid, cice_grid):
def test_cice_kmt(mom_grid, grids):
# Test : does the mask match
mask = mom_grid.mask_ds.mask
kmt = cice_grid.kmt_ds.kmt
kmt = grids["cice"].kmt_ds.kmt

assert_allclose(mask, kmt, rtol=1e-13, verbose=True, err_msg="mask mismatch")


def test_cice_grid_attributes(cice_grid):
def test_cice_grid_attributes(grids):
# Test: do the expected attributes to exist in the cice ds
# To-do: rewrite test using the CF-checker (or similar)
cf_attributes = {
"ulat": {"standard_name": "latitude", "units": "radians"},
"ulon": {"standard_name": "longitude", "units": "radians"},
Expand Down Expand Up @@ -184,48 +207,62 @@ def test_cice_grid_attributes(cice_grid):
"grid_mapping": "crs",
"coordinates": "tlat tlon",
},
"anglet": {
"standard_name": "angle_of_rotation_from_east_to_x",
"units": "radians",
"grid_mapping": "crs",
"coordinates": "tlat tlon",
},
"htn": {"units": "cm", "coordinates": "ulat tlon", "grid_mapping": "crs"},
"hte": {"units": "cm", "coordinates": "tlat ulon", "grid_mapping": "crs"},
}

for iVar in cf_attributes.keys():
print(cice_grid.ds[iVar])

for jAttr in cf_attributes[iVar].keys():
assert cice_grid.ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr]
for iVar in grids["cice"].ds.keys():
if iVar != "crs": # test seperately
for jAttr in cf_attributes[iVar].keys():
assert grids["cice"].ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr]


def test_crs_exist(cice_grid):
def test_crs_exist(grids):
# Test: has the crs been added ?
# todo: open with GDAL and rioxarray and confirm they find the crs?
assert hasattr(cice_grid.ds, "crs")
assert hasattr(cice_grid.kmt_ds, "crs")
assert hasattr(grids["cice"].ds, "crs")
assert hasattr(grids["cice"].kmt_ds, "crs")


def test_inputs_logged(cice_grid, mom_grid):
def test_inputs_logged(grids, mom_grid):
# Test: have the source data been logged ?

input_md5 = run(["md5sum", cice_grid.ds.inputfile], capture_output=True, text=True)
input_md5 = run(["md5sum", mom_grid.path], capture_output=True, text=True)
input_md5 = input_md5.stdout.split(" ")[0]
mask_md5 = run(["md5sum", cice_grid.kmt_ds.inputfile], capture_output=True, text=True)
mask_md5 = run(["md5sum", mom_grid.mask_path], capture_output=True, text=True)
mask_md5 = mask_md5.stdout.split(" ")[0]

for ds in [cice_grid.ds, cice_grid.kmt_ds]:
assert (
ds.inputfile
== (
mom_grid.path
+ " (md5 hash: "
+ input_md5
+ "), "
+ mom_grid.mask_path
+ " (md5 hash: "
+ mask_md5
+ ")"
),
"inputfile attribute incorrect ({ds.inputfile} != {mom_grid.path})",
)
for ds in [grids["cice"].ds, grids["cice"].kmt_ds]:
assert ds.inputfile == (
mom_grid.path + " (md5 hash: " + input_md5 + "), " + mom_grid.mask_path + " (md5 hash: " + mask_md5 + ")"
), "inputfile attribute incorrect ({ds.inputfile} != {mom_grid.path})"

assert hasattr(ds, "inputfile"), "inputfile attribute missing"

assert hasattr(ds, "history"), "history attribute missing"


def test_variant(mom_grid, tmp_path):
# Is a error given for variant not equal to None or 'cice5-auscom'

mom = MomGrid.fromfile(mom_grid.path, mask_file=mom_grid.mask_path)
cice = CiceGrid.fromgrid(mom)

with pytest.raises(NotImplementedError, match="andrew not recognised"):
cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc", variant="andrew")

try:
cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc", variant="cice5-auscom")
except:
assert False, "Failed to write cice grid with valid input arguments provided"

try:
cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc")
except:
assert False, "Failed to write cice grid with 'None' variant"
Loading