Skip to content

Commit

Permalink
Add the ability to decide whether we predict precipitation or evapora…
Browse files Browse the repository at this point in the history
…tion (and diagnose the other). Before we could only diagnose precipitation and predict evaporation.

PiperOrigin-RevId: 676049334
  • Loading branch information
yaniyuval authored and NeuralGCM authors committed Sep 18, 2024
1 parent c39cb12 commit 9d89731
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions neuralgcm/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@

TransformModule = typing.TransformModule

PRECIPITATION = 'precipitation'
EVAPORATION = 'evaporation'


class DiagnosticFn(Protocol):
"""Implements initialization and computation of model diagnostic fields."""
Expand Down Expand Up @@ -285,11 +288,11 @@ def __init__(
'specific_cloud_ice_water_content',
'specific_cloud_liquid_water_content',
),
feature_name: str = 'evaporation',
is_precipitation: bool = True,
method_precipitation: str = 'cumulative',
method_evaporation: str = 'rate',
name: Optional[str] = None,
field_name: str = 'precipitation_cumulative',
field_name: str = 'precipitation_cumulative_mean',
):
# del aux_features
super().__init__(name=name)
Expand All @@ -300,8 +303,20 @@ def __init__(
self.method_precipitation = method_precipitation
self.method_evaporation = method_evaporation
self.to_nodal_fn = coords.horizontal.to_nodal
self.is_precipitation = is_precipitation
if self.is_precipitation:
predicted_name = PRECIPITATION
diagnosed_name = EVAPORATION
else:
predicted_name = EVAPORATION
diagnosed_name = PRECIPITATION

output_shapes = {f'{feature_name}': np.asarray(coords.surface_nodal_shape)}
self.predicted_name = predicted_name
self.diagnosed_name = diagnosed_name

output_shapes = {
f'{predicted_name}': np.asarray(coords.surface_nodal_shape)
}

self.embedding_fn = embedding_module(
coords, dt, physics_specs, aux_features, output_shapes=output_shapes
Expand All @@ -321,13 +336,16 @@ def __call__(
e_minus_p = self._compute_evaporation_minus_precipitation(
model_state, physics_tendencies
)
evaporation = self.embedding_fn(
water_budget = self.embedding_fn(
model_state.state,
model_state.memory,
model_state.diagnostics,
model_state.randomness,
forcing,
)
water_budget[self.diagnosed_name] = (
-e_minus_p - water_budget[self.predicted_name]
)

# Note: In ERA5 mean_evaporation_rate (kg m**-2 s**-1)
# is negative for evaporation.
Expand All @@ -336,15 +354,16 @@ def __call__(
output_dict = {}
surface_nodal_shape = self.coords.horizontal.nodal_shape
if self.method_precipitation == 'rate': # units: length/time
output_dict['precipitation_rate'] = (
(-e_minus_p - evaporation['evaporation']) / self.water_density
output_dict[PRECIPITATION + '_rate'] = (
(water_budget[PRECIPITATION]) / self.water_density
)
elif self.method_precipitation == 'cumulative': # units: length
previous = model_state.diagnostics.get(
self.field_name, jnp.zeros(surface_nodal_shape)
)
output_dict[self.field_name] = previous - (
((e_minus_p + evaporation['evaporation']) / self.water_density)
assert self.field_name == 'precipitation_cumulative_mean', self.field_name
output_dict[self.field_name] = previous + (
(water_budget[PRECIPITATION] / self.water_density)
* self.dt
)
else:
Expand All @@ -353,14 +372,14 @@ def __call__(
' be `rate`/`cumulative`'
)
if self.method_evaporation == 'rate': # units: mass length**-2 time**-1
output_dict['evaporation'] = evaporation['evaporation']
output_dict[EVAPORATION] = water_budget[EVAPORATION]
elif self.method_evaporation == 'cumulative': # units: length
previous_evap = model_state.diagnostics.get(
'evaporation_cumulative', jnp.zeros(surface_nodal_shape)
EVAPORATION + '_cumulative', jnp.zeros(surface_nodal_shape)
)
output_dict['evaporation_cumulative'] = (
output_dict[EVAPORATION + '_cumulative'] = (
previous_evap
+ (evaporation['evaporation'] / self.water_density) * self.dt
+ (water_budget[EVAPORATION] / self.water_density) * self.dt
)
else:
raise ValueError(
Expand Down

0 comments on commit 9d89731

Please sign in to comment.