Skip to content

Commit

Permalink
Merge pull request #19 from dougiesquire/issue7-add-accessom3-model
Browse files Browse the repository at this point in the history
Add AccessOm3 model class
  • Loading branch information
dougiesquire committed Jun 5, 2024
2 parents 809122d + 213dd71 commit 8f5a421
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 82 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ dependencies = [
"requests",
"pytest >=8.0.1",
"ruamel.yaml >=0.18.5",
"jsonschema >=4.21.1"
"jsonschema >=4.21.1",
"payu >=1.1.3"
]

[project.optional-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion src/model_config_tests/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from model_config_tests.models.accessom2 import AccessOm2
from model_config_tests.models.accessom3 import AccessOm3

index = {"access-om2": AccessOm2}
index = {"access-om2": AccessOm2, "access-om3": AccessOm3}
45 changes: 9 additions & 36 deletions src/model_config_tests/models/accessom2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@

import f90nml

from model_config_tests.models.model import Model

BASE_SCHEMA_URL = "https://raw.githubusercontent.com/ACCESS-NRI/schema/main/au.org.access-nri/model/access-om2/experiment/reproducibility/checksums"

SCHEMA_VERSION_1_0_0 = "1-0-0"
DEFAULT_SCHEMA_VERSION = SCHEMA_VERSION_1_0_0
SUPPORTED_SCHEMA_VERSIONS = [SCHEMA_VERSION_1_0_0]
from model_config_tests.models.model import SCHEMA_VERSION_1_0_0, Model


class AccessOm2(Model):
Expand All @@ -23,14 +17,20 @@ def __init__(self, experiment):

self.accessom2_config = experiment.control_path / "accessom2.nml"
self.ocean_config = experiment.control_path / "ocean" / "input.nml"
self.default_schema_version = DEFAULT_SCHEMA_VERSION

def set_model_runtime(self, years: int = 0, months: int = 0, seconds: int = 10800):
"""Set config files to a short time period for experiment run.
Default is 3 hours"""
with open(self.accessom2_config) as f:
nml = f90nml.read(f)

# Check that two of years, months, seconds is zero
if sum(x == 0 for x in (years, months, seconds)) != 2:
raise NotImplementedError(
"Cannot specify runtime in seconds and years and months"
+ " at the same time. Two of which must be zero"
)

nml["date_manager_nml"]["restart_period"] = [years, months, seconds]
nml.write(self.accessom2_config, force=True)

Expand Down Expand Up @@ -75,7 +75,7 @@ def extract_checksums(
output_checksums[field].append(checksum)

if schema_version is None:
schema_version = DEFAULT_SCHEMA_VERSION
schema_version = self.default_schema_version

if schema_version == SCHEMA_VERSION_1_0_0:
checksums = {
Expand All @@ -88,30 +88,3 @@ def extract_checksums(
)

return checksums

def check_checksums_over_restarts(
    self,
    long_run_checksum: dict[str, Any],
    short_run_checksum_0: dict[str, Any],
    short_run_checksum_1: dict[str, Any],
) -> bool:
    """Compare checksums from a long run (e.g. 2 days) against
    checksums from 2 short runs (e.g. 1 day).

    Each argument is a checksum dictionary whose "output" entry maps a
    field name to a list of checksums. Every checksum of the long run
    must appear, under the same field, in one of the two short runs.

    Returns True when all long-run checksums are found, False otherwise.
    Each missing checksum is printed. The input dictionaries are not
    modified.
    """
    # Merge the short runs into a fresh mapping with fresh lists so the
    # caller's dictionaries (and the lists inside them) are left untouched
    short_run_checksums: dict[str, list] = {}
    for short_run in (short_run_checksum_0, short_run_checksum_1):
        for field, checksums in short_run["output"].items():
            short_run_checksums.setdefault(field, []).extend(checksums)

    matching_checksums = True
    for field, checksums in long_run_checksum["output"].items():
        # A field absent from both short runs matches nothing
        expected = short_run_checksums.get(field, ())
        for checksum in checksums:
            if checksum not in expected:
                print(f"Unequal checksum: {field}: {checksum}")
                matching_checksums = False

    return matching_checksums
90 changes: 90 additions & 0 deletions src/model_config_tests/models/accessom3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Specific Access-OM3 Model setup and post-processing"""

import re
from collections import defaultdict
from pathlib import Path
from typing import Any

from payu.models.cesm_cmeps import Runconfig

from model_config_tests.models.model import SCHEMA_VERSION_1_0_0, Model


class AccessOm3(Model):
    """Access-OM3 specific model setup and checksum extraction.

    Runtime is configured through the experiment's ``nuopc.runconfig`` file
    and checksums are read from the ``ocean.stats`` output file.
    """

    def __init__(self, experiment):
        """Record the paths of the output and configuration files that are
        used, relative to the experiment's output and control directories"""
        super().__init__(experiment)
        self.output_file = self.experiment.output000 / "ocean.stats"

        self.runconfig = experiment.control_path / "nuopc.runconfig"
        self.ocean_config = experiment.control_path / "input.nml"

    def set_model_runtime(self, years: int = 0, months: int = 0, seconds: int = 10800):
        """Set config files to a short time period for experiment run.
        Default is 3 hours.

        The runtime is given either in seconds (years and months both zero)
        or in years and/or months (seconds zero). Any other combination
        raises NotImplementedError.
        """
        runconfig = Runconfig(self.runconfig)

        if years == months == 0:
            freq = "nseconds"
            n = str(seconds)
        elif seconds == 0:
            # Years are expressed as an equivalent number of months
            freq = "nmonths"
            n = str(12 * years + months)
        else:
            raise NotImplementedError(
                "Cannot specify runtime in seconds and year/months at the same time"
            )

        # The restart and stop settings are given the same length
        runconfig.set("CLOCK_attributes", "restart_n", n)
        runconfig.set("CLOCK_attributes", "restart_option", freq)
        runconfig.set("CLOCK_attributes", "stop_n", n)
        runconfig.set("CLOCK_attributes", "stop_option", freq)

        runconfig.write()

    def output_exists(self) -> bool:
        """Check for existing output file"""
        return self.output_file.exists()

    def extract_checksums(
        self, output_directory: Path = None, schema_version: str = None
    ) -> dict[str, Any]:
        """Parse output file and create checksum using defined schema.

        ``ocean.stats`` is read from ``output_directory`` when given,
        otherwise from the experiment's default output file. When
        ``schema_version`` is None the model's default schema version is
        used. An unsupported schema version raises NotImplementedError.
        """
        if output_directory:
            output_filename = output_directory / "ocean.stats"
        else:
            output_filename = self.output_file

        # ocean.stats is used for regression testing in MOM6's own test suite
        # See https://github.com/mom-ocean/MOM6/blob/2ab885eddfc47fc0c8c0bae46bc61531104428d5/.testing/Makefile#L495-L501
        # Rows in ocean.stats look like:
        #      0,  693135.000,     0, En 3.0745627134675957E-23, CFL  0.00000, ...
        # where the first three columns are Step, Day, Truncs and the remaining
        # columns include a label for what they are (e.g. En = Energy/Mass)
        # Header info is only included for new runs so can't be relied on
        output_checksums: dict[str, list[Any]] = defaultdict(list)

        with open(output_filename) as f:
            lines = f.readlines()

        # Skip header if it exists (for new runs). The emptiness check keeps
        # an empty file from raising IndexError; it yields no checksums
        istart = 2 if lines and "Step" in lines[0] else 0
        for line in lines[istart:]:
            for col in line.split(","):
                # Only keep columns with labels (ie not Step, Day, Truncs)
                col = re.split(" +", col.strip())
                if len(col) > 1:
                    output_checksums[col[0]].append(col[-1])

        if schema_version is None:
            schema_version = self.default_schema_version

        if schema_version == SCHEMA_VERSION_1_0_0:
            checksums = {
                "schema_version": schema_version,
                "output": dict(output_checksums),
            }
        else:
            raise NotImplementedError(
                f"Unsupported checksum schema version: {schema_version}"
            )

        return checksums
27 changes: 26 additions & 1 deletion src/model_config_tests/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@

from pathlib import Path

# Identifier of the checksum schema version supported by the models
SCHEMA_VERSION_1_0_0 = "1-0-0"

# Location of the 1-0-0 schema; the path contains a commit hash, so the
# file is presumably pinned to that revision of the schema repository
SCHEMA_1_0_0_URL = "https://raw.githubusercontent.com/ACCESS-NRI/schema/7666d95967de4dfd19b0d271f167fdcfd3f46962/au.org.access-nri/model/reproducibility/checksums/1-0-0.json"

# Schema version used when a caller does not request a specific one
DEFAULT_SCHEMA_VERSION = SCHEMA_VERSION_1_0_0

# Lookup from each supported schema version to the URL of its schema
SCHEMA_VERSION_TO_URL = {
    SCHEMA_VERSION_1_0_0: SCHEMA_1_0_0_URL,
}


class Model:
def __init__(self, experiment):
self.experiment = experiment

self.default_schema_version = DEFAULT_SCHEMA_VERSION
self.schema_version_to_url = SCHEMA_VERSION_TO_URL

def extract_checksums(self, output_directory: Path, schema_version: str):
"""Extract checksums from output directory"""
raise NotImplementedError
Expand All @@ -24,4 +32,21 @@ def check_checksums_over_restarts(
) -> bool:
"""Compare a checksums from a long run (e.g. 2 days) against
checksums from 2 short runs (e.g. 1 day)"""
raise NotImplementedError
short_run_checksums = short_run_checksum_0["output"]
for field, checksums in short_run_checksum_1["output"].items():
if field not in short_run_checksums:
short_run_checksums[field] = checksums
else:
short_run_checksums[field].extend(checksums)

matching_checksums = True
for field, checksums in long_run_checksum["output"].items():
for checksum in checksums:
if (
field not in short_run_checksums
or checksum not in short_run_checksums[field]
):
print(f"Unequal checksum: {field}: {checksum}")
matching_checksums = False

return matching_checksums
File renamed without changes.
32 changes: 32 additions & 0 deletions tests/resources/access-om3/checksums/1-0-0.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"schema_version": "1-0-0",
"output": {
"En": [
"3.0745627134675957E-23"
],
"CFL": [
"0.00000"
],
"SL": [
"1.5112E-10"
],
"M": [
"1.36404E+21"
],
"S": [
"34.7263"
],
"T": [
"3.6362"
],
"Me": [
"0.00E+00"
],
"Se": [
"0.00E+00"
],
"Te": [
"0.00E+00"
]
}
}
3 changes: 3 additions & 0 deletions tests/resources/access-om3/output000/ocean.stats
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Step, Day, Truncs, Energy/Mass, Maximum CFL, Mean Sea Level, Total Mass, Mean Salin, Mean Temp, Frac Mass Err, Salin Err, Temp Err
[days] [m2 s-2] [Nondim] [m] [kg] [PSU] [degC] [Nondim] [PSU] [degC]
0, 693135.000, 0, En 3.0745627134675957E-23, CFL 0.00000, SL 1.5112E-10, M 1.36404E+21, S 34.7263, T 3.6362, Me 0.00E+00, Se 0.00E+00, Te 0.00E+00
43 changes: 0 additions & 43 deletions tests/test_access_om2_extract_checksums.py

This file was deleted.

73 changes: 73 additions & 0 deletions tests/test_model_extract_checksums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
import os
from pathlib import Path
from unittest.mock import Mock

import jsonschema
import pytest
import requests

from model_config_tests.models import index as model_index

# Names of every registered model; each test below runs once per model
MODEL_NAMES = model_index.keys()
# Directory containing this test module
HERE = os.path.dirname(__file__)
# Join with the path operator rather than an f-string: if HERE is ever an
# empty string, f"{HERE}/resources" would point at the filesystem root
RESOURCES_DIR = Path(HERE) / "resources"


@pytest.mark.parametrize("model_name", MODEL_NAMES)
def test_extract_checksums(model_name):
    """Check extracted checksums for every schema version of a model.

    For each version the extracted checksums must report that version,
    equal the stored reference file and validate against the schema
    retrieved from the version's URL.
    """
    model_resources = RESOURCES_DIR / model_name

    # Stand-in for ExpTestHelper providing the paths the model reads
    experiment = Mock()
    experiment.output000 = model_resources / "output000"
    experiment.control_path = Path("test/tmp")

    # Build the model under test from the registry
    model = model_index[model_name](experiment)

    for version, url in model.schema_version_to_url.items():
        extracted = model.extract_checksums(schema_version=version)

        # The requested version must be echoed back
        assert extracted["schema_version"] == version

        # The whole result must equal the stored reference checksums
        reference_path = model_resources / "checksums" / f"{version}.json"
        with open(reference_path) as reference_file:
            reference = json.load(reference_file)

        assert extracted == reference

        # The extracted checksums must also conform to the schema
        schema = get_schema_from_url(url)
        jsonschema.validate(instance=extracted, schema=schema)


@pytest.mark.parametrize("model_name", MODEL_NAMES)
def test_extract_checksums_unsupported_version(model_name):
    """Check that an unknown schema version raises NotImplementedError"""
    model_resources = RESOURCES_DIR / model_name

    # Stand-in for ExpTestHelper providing the paths the model reads
    experiment = Mock()
    experiment.output000 = model_resources / "output000"
    experiment.control_path = Path("test/tmp")

    # Build the model under test from the registry
    model = model_index[model_name](experiment)

    # "test-version" is not a supported schema version
    with pytest.raises(NotImplementedError):
        model.extract_checksums(schema_version="test-version")


def get_schema_from_url(url, timeout=30):
    """Retrieve schema from GitHub.

    ``url`` is the address of the JSON schema and ``timeout`` is the number
    of seconds passed to ``requests.get`` as its timeout. Returns the
    decoded JSON body. Fails with an AssertionError when the response
    status is not 200.
    """
    # Without a timeout the request, and so the test run, could wait forever
    response = requests.get(url, timeout=timeout)
    assert (
        response.status_code == 200
    ), f"Failed to retrieve schema from {url}: HTTP {response.status_code}"
    return response.json()

0 comments on commit 8f5a421

Please sign in to comment.