Skip to content

Commit

Permalink
Merge branch 'feature/validation'
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Feb 15, 2024
2 parents 5e79ae6 + 2fd2cc1 commit e703d6d
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 15 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ dependencies = [
"pre-commit",
"xarray",
"netcdf4==1.6.3",
"f90nml",
"GitPython",
]

[tool.setuptools]
Expand All @@ -52,3 +54,4 @@ tcn-ci = "tcn.ci.dispatch:cli"
tcn-hws = "tcn.hws.cli:cli"
tcn-fpy = "tcn.py_ftn_interface.cli:cli"
tcn-plots = "tcn.plots.cli:cli"
tcn-validation = "tcn.validation.cli:cli"
43 changes: 43 additions & 0 deletions src/tcn/validation/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import xarray as xr
import plotly.express as px
from typing import Optional
import numpy as np


def analysis(
    ref_dataset: xr.Dataset,
    cpu_dataset: xr.Dataset,
    variable: Optional[str],
    time: int = 0,
):
    """Print the extrema of ``cpu - ref`` for one variable and save a histogram.

    The variable is sliced at index ``time`` along the ``time`` dimension in
    both datasets, the element-wise difference ``cpu - ref`` is flattened and
    NaN entries are dropped. The max/min of what remains are printed and a
    histogram (log-scaled y axis) is written to ``<variable>_hist.png`` in the
    current working directory.

    Args:
        ref_dataset: Dataset used as the reference.
        cpu_dataset: Dataset compared against the reference; it must contain
            the same variable.
        variable: Name of the variable to analyse. ``None``, an empty string
            or a name that is not one of ``ref_dataset.keys()`` is a no-op.
        time: Index handed to ``isel(time=...)`` on both variables.
    """
    # Materialise the keys: membership is tested against the names the
    # dataset iterates over, exactly as the name-by-name comparison did.
    known_names = list(ref_dataset.keys())
    if not variable or variable not in known_names:
        return

    name = variable
    ref_var = ref_dataset[name].isel(time=time)
    cpu_var = cpu_dataset[name].isel(time=time)
    diff = (cpu_var - ref_var).values.flatten()
    diff = diff[~np.isnan(diff)]
    if diff.size == 0:
        # max()/min() raise on an empty array; report instead of crashing.
        print(f"{name}:\n No non-NaN difference to analyse")
        return
    print(f"{name}:\n " f"Max: {diff.max():.2f}\n" f" Min: {diff.min():.2f}")

    # Metadata is optional: fall back to the raw variable name / no unit.
    var_name = ref_var.attrs.get("long_name", name).replace("_", " ").title()
    units = ref_var.attrs.get("units")
    xaxis_title = f"Difference in {units}" if units else "Difference"
    fig = px.histogram(
        x=diff,
        log_y=True,
    )
    fig.update_layout(
        title=f"{var_name} ({name})",
        xaxis_title=xaxis_title,
    )
    fig.write_image(f"{name}_hist.png")


if __name__ == "__main__":
    import sys

    # Script entry point: <reference dataset> <computed dataset> <variable>.
    if len(sys.argv) < 4:
        # Fail with a usage line rather than a bare IndexError.
        sys.exit(f"Usage: {sys.argv[0]} REFERENCE_NC4 COMPUTED_NC4 VARIABLE")

    var = sys.argv[3]
    # Context managers close the underlying files once the analysis is done.
    with xr.open_mfdataset(sys.argv[1]) as ref_dataset:
        with xr.open_mfdataset(sys.argv[2]) as cpu_dataset:
            analysis(
                ref_dataset,
                cpu_dataset,
                variable=var,
            )
51 changes: 51 additions & 0 deletions src/tcn/validation/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import click
from tcn.validation.analysis import analysis
import tcn.validation.serialbox.serialbox_dat_to_netcdf as sdnc
import xarray as xr


@click.group()
def cli():
    """Validation tools: compare NetCDF outputs and convert Serialbox data."""
    # The docstring above is what click prints as the group description in
    # `--help`; the group itself carries no logic.


@click.command()
@click.argument("reference_nc4", type=str)
@click.argument("computed_nc4", type=str)
@click.argument("variable", type=str)
@click.option(
    "--select_time",
    "-st",
    type=int,
    default=0,
    show_default=True,
    help="Index along the time dimension at which the datasets are compared.",
)
def validate(
    reference_nc4: str,
    computed_nc4: str,
    variable: str,
    select_time: int = 0,
):
    """Compare VARIABLE between REFERENCE_NC4 and COMPUTED_NC4.

    Both paths are opened with xarray's open_mfdataset and handed to the
    analysis routine together with the selected time index.
    """
    # Open both datasets as context managers so their files are closed once
    # the analysis has run, instead of leaking the handles.
    with xr.open_mfdataset(reference_nc4) as ref_dataset:
        with xr.open_mfdataset(computed_nc4) as cpu_dataset:
            analysis(
                ref_dataset=ref_dataset,
                cpu_dataset=cpu_dataset,
                variable=variable,
                time=select_time,
            )


@click.command()
@click.argument("data_path_of_dat", type=str)
@click.argument("output_path", type=str)
@click.option(
    "--rank",
    "-r",
    type=int,
    default=-1,
    show_default=True,
    help="Forwarded to the converter as do_only_rank.",
)
@click.option(
    "--savepoint",
    "-s",
    type=int,
    default=-1,
    show_default=True,
    help="Forwarded to the converter as do_only_savepoint.",
)
def serialbox(
    data_path_of_dat: str,
    output_path: str,
    rank: int,
    savepoint: int,
):
    """Convert Serialbox data in DATA_PATH_OF_DAT to NetCDF.

    OUTPUT_PATH is handed to the converter as its output location.
    """
    # NOTE(review): the -1 defaults presumably mean "no restriction" on
    # rank / savepoint — confirm against sdnc.main before documenting it
    # in the user-facing help.
    sdnc.main(
        data_path=data_path_of_dat,
        output_path=output_path,
        do_only_rank=rank,
        do_only_savepoint=savepoint,
    )


# Attach every subcommand to the `cli` group, in declaration order.
for _subcommand in (validate, serialbox):
    cli.add_command(_subcommand)
74 changes: 74 additions & 0 deletions src/tcn/validation/geos_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import dataclasses
from dataclasses import dataclass
from git import Repo
import yaml
import pathlib
from typing import List, Optional


@dataclass
class RepositoryStatus:
    """State of a single repository: its name, commit hash and optional tag."""

    # Identifier of the repository.
    name: str
    # Hexadecimal SHA of the commit the repository is at.
    hexsha: str
    # Optional tag information for that commit; None when not provided.
    tag: Optional[str] = None


@dataclass
class GEOSStatus:
    """Collection of repository states describing one checkout.

    Two statuses are equal when they hold the same set of ``(name, hexsha)``
    pairs. Tags and the order of ``repositories`` do not take part in the
    comparison.
    """

    # One entry per repository that is part of the checkout.
    repositories: List[RepositoryStatus] = dataclasses.field(default_factory=list)

    def __eq__(self, other: object) -> bool:
        """Return True when both statuses pin the same repositories and commits.

        Returns ``NotImplemented`` for non-``GEOSStatus`` operands so Python
        can fall back to its default handling (``status == 5`` is ``False``)
        instead of raising from inside an equality test.
        """
        if not isinstance(other, GEOSStatus):
            return NotImplemented
        # Compare as sets in both directions: a repository present on only
        # one side must make the statuses differ regardless of operand order.
        mine = {(r.name, r.hexsha) for r in self.repositories}
        theirs = {(r.name, r.hexsha) for r in other.repositories}
        return mine == theirs


def _get_all_repo_status(
    mepo_components_path: str, verbose: bool = False
) -> GEOSStatus:
    """Collect the HEAD commit (and matching tag) of every local component.

    The YAML file at ``mepo_components_path`` is read as a mapping of
    component name to configuration. For each component whose configuration
    has a ``local`` entry, the git repository at that path (relative to the
    directory containing the YAML file) is opened and its HEAD commit hash
    recorded, together with the name of the first tag pointing at that
    commit, if any.

    Args:
        mepo_components_path: Path to the components YAML file.
        verbose: When true, print one ``name / hash / tag`` line per
            repository.

    Returns:
        A ``GEOSStatus`` holding one ``RepositoryStatus`` per component with
        a ``local`` entry, in file order.
    """
    geos_dir = pathlib.Path(mepo_components_path).parent.resolve()
    with open(mepo_components_path) as f:
        # An empty YAML document loads as None; treat it as "no components".
        comps = yaml.safe_load(f) or {}

    all_repos: List[RepositoryStatus] = []
    for comp, config in comps.items():
        if "local" not in config:
            continue
        # Repo is a context manager: leaving the block releases the git
        # resources it holds instead of keeping one open per component.
        with Repo(geos_dir / config["local"]) as repo:
            hexsha = repo.head.commit.hexsha
            tag = next(
                (t.name for t in repo.tags if t.commit.hexsha == hexsha),
                None,
            )
        # Store the raw tag name (or None); the " (tag: ...)" suffix is
        # display formatting only and stays out of the data record.
        all_repos.append(RepositoryStatus(comp, hexsha, tag))
        if verbose:
            tag_as_str = f" (tag: {tag})" if tag else ""
            print(f"{comp:<25}{hexsha}{tag_as_str}")
    return GEOSStatus(all_repos)


if __name__ == "__main__":
    # Ad-hoc check: read the "hs" checkout twice and the "aq" checkout once,
    # then require all three statuses to compare equal.
    hs_components = "/home/fgdeconi/work/git/hs/geos/components.yaml"
    aq_components = "/home/fgdeconi/work/git/aq/geos/components.yaml"

    hs = _get_all_repo_status(hs_components, verbose=True)
    hs2 = _get_all_repo_status(hs_components, verbose=True)
    aq = _get_all_repo_status(aq_components, verbose=True)

    assert hs == hs2
    assert hs == aq
30 changes: 15 additions & 15 deletions src/tcn/validation/serialbox/serialbox_dat_to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@
import f90nml
import numpy as np

SERIALBOX_PYTHON = os.getenv("SERIALBOX_PYTHON", "")
if SERIALBOX_PYTHON == "":
raise RuntimeError(
"You must define env var SERIALBOX_PYTHON to point to root python package of serialbox."
)

sys.path.append(SERIALBOX_PYTHON)
import serialbox # noqa: E402


def get_parser():
parser = argparse.ArgumentParser("converts serialbox data to netcdf")
Expand Down Expand Up @@ -61,15 +52,15 @@ def read_serialized_data(serializer, savepoint, variable):
return data


def get_all_savepoint_names(data_path):
def get_all_savepoint_names(serialbox, data_path):
savepoint_names = set()
serializer = get_serializer(data_path, rank=0)
serializer = get_serializer(serialbox, data_path, rank=0)
for savepoint in serializer.savepoint_list():
savepoint_names.add(savepoint.name)
return savepoint_names


def get_serializer(data_path, rank):
def get_serializer(serialbox, data_path, rank):
return serialbox.Serializer(
serialbox.OpenModeKind.Read, data_path, "Generator_rank" + str(rank)
)
Expand All @@ -81,6 +72,15 @@ def main(
do_only_rank: int = -1,
do_only_savepoint: int = -1,
):
SERIALBOX_PYTHON = os.getenv("SERIALBOX_PYTHON", "")
if SERIALBOX_PYTHON == "":
raise RuntimeError(
"You must define env var SERIALBOX_PYTHON to point to root python package of serialbox."
)

sys.path.append(SERIALBOX_PYTHON)
import serialbox # noqa: E402

print("Make directory & read namelist... 🚧")
os.makedirs(output_path, exist_ok=True)
namelist_filename_in = os.path.join(data_path, "input.nml")
Expand All @@ -94,12 +94,12 @@ def main(
print("Done ✅")

print("Read savepoints... 🚧")
savepoint_names = get_all_savepoint_names(data_path)
savepoint_names = get_all_savepoint_names(serialbox, data_path)
print(f"Read {savepoint_names}... ✅")
for savepoint_name in sorted(list(savepoint_names)):
rank_list = []
# all ranks have the same names, just look at first one
serializer = get_serializer(data_path, rank=0)
serializer = get_serializer(serialbox, data_path, rank=0)
names_list = list(
serializer.fields_at_savepoint(serializer.get_savepoint(savepoint_name)[0])
)
Expand All @@ -111,7 +111,7 @@ def main(
if do_only_rank >= 0:
rank = do_only_rank
print(f"/ / Processing rank {rank}")
serializer = get_serializer(data_path, rank)
serializer = get_serializer(serialbox, data_path, rank)
serializer_list.append(serializer)
savepoints = serializer.get_savepoint(savepoint_name)
rank_data = {}
Expand Down

0 comments on commit e703d6d

Please sign in to comment.