diff --git a/test/test_um2netcdf.py b/test/test_um2netcdf.py
index 92bb268..a88811c 100644
--- a/test/test_um2netcdf.py
+++ b/test/test_um2netcdf.py
@@ -15,6 +15,8 @@
 import iris.coords
 import iris.exceptions
 
+from netCDF4 import default_fillvals
+
 D_LAT_N96 = 1.25   # Degrees between latitude points on N96 grid
 D_LON_N96 = 1.875  # Degrees between longitude points on N96 grid
 
@@ -1207,3 +1209,56 @@ def test_convert_32_bit_with_float64(ua_plev_cube):
     ua_plev_cube.data = array
     um2nc.convert_32_bit(ua_plev_cube)
     assert ua_plev_cube.data.dtype == np.float32
+
+
+@pytest.mark.parametrize(
+    "cube_data, custom_fill_val, expected_fill_val",
+    [
+        (np.array([1.1, 2.1], dtype="float32"),
+         None,
+         np.float32(um2nc.DEFAULT_FILL_VAL_FLOAT)),
+        (np.array([1.1, 2.1], dtype="float64"),
+         None,
+         np.float64(um2nc.DEFAULT_FILL_VAL_FLOAT)),
+        (np.array([1.1, 2.1], dtype="complex64"),
+         None,
+         np.complex64(default_fillvals["c8"])),
+        (np.array([1, 2], dtype="int32"),
+         None,
+         np.int32(default_fillvals["i4"])),
+        (np.array([1, 2], dtype=np.dtype("int64")),
+         None,
+         np.int64(default_fillvals["i8"])),
+        (np.array([1, 2], dtype=np.dtype("int64")),
+         -12345,
+         np.int64(-12345))
+    ]
+)
+def test_fix_fill_value(cube_data, custom_fill_val, expected_fill_val):
+    """
+    Check that correct default and custom fill values are added based
+    on a cube's data's type.
+    """
+    fake_cube = DummyCube(12345, "fake_var", attributes={})
+    fake_cube.data = cube_data
+
+    um2nc.fix_fill_value(fake_cube, custom_fill_val)
+
+    cube_fill_val = fake_cube.attributes["missing_value"]
+
+    assert cube_fill_val[0] == expected_fill_val
+    # Check new fill value type matches cube's data's type
+    assert cube_fill_val.dtype == cube_data.dtype
+
+
+def test_fix_fill_value_wrong_type():
+    """
+    Check that an error is raised when a custom fill value's
+    type does not match the cube's data's type.
+    """
+    fake_cube = DummyCube(12345, "fake_var", attributes={})
+    fake_cube.data = np.array([1, 2, 3], dtype="int32")
+    custom_fill_val = np.float32(151.11)
+
+    with pytest.raises(TypeError):
+        um2nc.fix_fill_value(fake_cube, custom_fill_val)
diff --git a/umpost/um2netcdf.py b/umpost/um2netcdf.py
index de886d3..b0f52d6 100644
--- a/umpost/um2netcdf.py
+++ b/umpost/um2netcdf.py
@@ -70,6 +70,8 @@
 LEVEL_HEIGHT = "level_height"
 SIGMA = "sigma"
 
+DEFAULT_FILL_VAL_FLOAT = 1.e20
+
 
 class PostProcessingError(Exception):
     """Generic class for um2nc specific errors."""
@@ -319,6 +321,46 @@ def fix_latlon_coords(cube, grid_type, dlat, dlon):
     fix_lon_coord_name(longitude_coordinate, grid_type, dlon)
 
 
+def fix_fill_value(cube, custom_fill_val=None):
+    """
+    Set a cube's missing_value attribute according to the data's dtype.
+
+    Float data defaults to DEFAULT_FILL_VAL_FLOAT; other dtypes default to
+    the netCDF4 fill value for their kind and item size.
+
+    Parameters
+    ----------
+    cube: Iris cube object (modified in place).
+    custom_fill_val: Optional custom fill value. Type should match
+        the cube data's type.
+
+    Raises
+    ------
+    TypeError: if custom_fill_val's type does not match the cube data's dtype.
+    """
+    data_dtype = cube.data.dtype
+
+    if custom_fill_val is not None:
+        # Convert the value's type to a dtype explicitly instead of relying
+        # on dtype.__eq__ to coerce a bare type object during comparison.
+        if np.dtype(type(custom_fill_val)) != data_dtype:
+            msg = (f"custom_fill_val type {type(custom_fill_val)} does not "
+                   f"match cube {cube.name()} data type {data_dtype}.")
+            raise TypeError(msg)
+
+        fill_value = custom_fill_val
+    elif data_dtype.kind == 'f':
+        fill_value = DEFAULT_FILL_VAL_FLOAT
+    else:
+        # Use netCDF defaults, keyed by dtype kind + item size (e.g. "i4")
+        fill_value = default_fillvals[
+            f"{data_dtype.kind:s}{data_dtype.itemsize:1d}"
+        ]
+
+    # Use an array to force the type to match the data type
+    cube.attributes['missing_value'] = np.array([fill_value], data_dtype)
+
+
 # TODO: split cube ops into functions, this will likely increase process() workflow steps
 def cubewrite(cube, sman, compression, use64bit, verbose):
     # TODO: move into process() AND if a new cube is returned, swap into filtered cube list
@@ -328,15 +370,7 @@ def cubewrite(cube, sman, compression, use64bit, verbose):
     if not use64bit:
         convert_32_bit(cube)
 
-    # Set the missing_value attribute. Use an array to force the type to match
-    # the data type
-    if cube.data.dtype.kind == 'f':
-        fill_value = 1.e20
-    else:
-        # Use netCDF defaults
-        fill_value = default_fillvals['%s%1d' % (cube.data.dtype.kind, cube.data.dtype.itemsize)]
-
-    cube.attributes['missing_value'] = np.array([fill_value], cube.data.dtype)
+    fix_fill_value(cube)
 
     # If reference date is before 1600 use proleptic gregorian
     # calendar and change units from hours to days