diff --git a/payu/models/mom6.py b/payu/models/mom6.py index 86943975..e4a48d12 100644 --- a/payu/models/mom6.py +++ b/payu/models/mom6.py @@ -10,6 +10,7 @@ # Standard library import os +import shutil # Extensions import f90nml @@ -31,27 +32,12 @@ def __init__(self, expt, name, config): self.config_files = [ 'input.nml', - 'MOM_input', 'diag_table', ] - # TODO: Need to figure out what's going on here with MOM6 self.optional_config_files = [ 'data_table', - 'data_table.MOM6', - 'data_table.OM4', - 'data_table.SIS', - 'data_table.icebergs', - - 'field_table', - - 'MOM_override', - 'MOM_layout', - 'MOM_saltrestore', - - 'SIS_input', - 'SIS_override', - 'SIS_layout', + 'field_table' ] def setup(self): @@ -59,6 +45,7 @@ def setup(self): super(Mom6, self).setup() self.init_config() + self.add_config_files() def init_config(self): """Patch input.nml as a new or restart run.""" @@ -78,3 +65,47 @@ def init_config(self): input_nml['SIS_input_nml']['input_filename'] = input_type f90nml.write(input_nml, input_fpath, force=True) + + def add_config_files(self): + """Add to model configuration files""" + + # Add parameter config files + config_files_to_add = self.get_parameter_files() + + # Set of all configuration files + all_config_files = set(self.config_files).union( + self.optional_config_files) + + for filename in config_files_to_add: + if filename not in all_config_files: + # Extend config files + self.config_files.append(filename) + all_config_files.add(filename) + + # Copy file from control path to work path + file_path = os.path.join(self.control_path, filename) + shutil.copy(file_path, self.work_path) + + def get_parameter_files(self): + """Return a list of parameter config files defined in input.nml""" + input_nml = f90nml.read(os.path.join(self.work_path, 'input.nml')) + + input_namelists = ['MOM_input_nml'] + if 'SIS_input_nml' in input_nml: + input_namelists.append('SIS_input_nml') + + parameter_files = [] + for input in input_namelists: + input_namelist = input_nml.get(input, {}) + filenames = input_namelist.get('parameter_filename', []) + + if filenames == []: + print("payu: warning: MOM6: There are no parameter files " + f"listed under {input} in input.nml") + + if isinstance(filenames, str): + parameter_files.append(filenames) + else: + parameter_files.extend(filenames) + + return parameter_files diff --git a/test/common.py b/test/common.py index 639bb16a..867a7938 100644 --- a/test/common.py +++ b/test/common.py @@ -13,13 +13,17 @@ from payu.subcommands.setup_cmd import runcmd as payu_setup_orignal from payu.subcommands.sweep_cmd import runcmd as payu_sweep +ctrldir_basename = 'ctrl' + testdir = Path().cwd() / Path('test') tmpdir = testdir / 'tmp' -ctrldir = tmpdir / 'ctrl' +ctrldir = tmpdir / ctrldir_basename labdir = tmpdir / 'lab' workdir = ctrldir / 'work' payudir = tmpdir / 'payu' +expt_workdir = labdir / 'work' / ctrldir_basename + print('tmpdir: {}'.format(tmpdir)) config = { @@ -43,6 +47,7 @@ } } + @contextmanager def cd(directory): """ diff --git a/test/models/test_mom6.py b/test/models/test_mom6.py new file mode 100644 index 00000000..64b1f923 --- /dev/null +++ b/test/models/test_mom6.py @@ -0,0 +1,134 @@ +import copy +import os +import shutil + +import pytest +import f90nml + +import payu + +from test.common import cd +from test.common import tmpdir, ctrldir, labdir, expt_workdir +from test.common import config as config_orig +from test.common import write_config +from test.common import make_random_file, make_inputs + +verbose = True + +# Global config +config = copy.deepcopy(config_orig) +config["model"] = "mom6" + + +def setup_module(module): + """ + Put any test-wide setup code in here, e.g. creating test files + """ + if verbose: + print("setup_module module:%s" % module.__name__) + + # Should be taken care of by teardown, in case remnants lying around + try: + shutil.rmtree(tmpdir) + except FileNotFoundError: + pass + + try: + tmpdir.mkdir() + labdir.mkdir() + ctrldir.mkdir() + expt_workdir.mkdir(parents=True) + make_inputs() + except Exception as e: + print(e) + + write_config(config) + + +def teardown_module(module): + """ + Put any test-wide teardown code in here, e.g. removing test outputs + """ + if verbose: + print("teardown_module module:%s" % module.__name__) + + try: + shutil.rmtree(tmpdir) + print('removing tmp') + except Exception as e: + print(e) + + +@pytest.fixture(autouse=True) +def teardown(): + # Run test + yield + + # Remove any files in expt work directory + for file in os.listdir(expt_workdir): + try: + os.remove(os.path.join(expt_workdir, file)) + except Exception as e: + print(e) + + +@pytest.mark.parametrize( + "input_nml, expected_files_added", + [ + ( + { + "MOM_input_nml": { + "parameter_filename": "MOM_Input" + } + }, + ["MOM_Input"] + ), + ( + { + "SIS_input_nml": { + "parameter_filename": "SIS_Input" + } + }, + ["SIS_Input"] + ), + ( + { + "MOM_input_nml": { + "parameter_filename": ["MOM_Input", "MOM_override"] + }, + "SIS_input_nml": { + "output_directory": '.' + } + }, + ["MOM_Input", "MOM_override"] + ) + ]) +def test_add_config_files(input_nml, + expected_files_added): + # Create config files in control directory + for file in expected_files_added: + filename = os.path.join(ctrldir, file) + make_random_file(filename, 8) + + # Create config.nml + input_nml_fp = os.path.join(expt_workdir, 'input.nml') + f90nml.write(input_nml, input_nml_fp) + + with cd(ctrldir): + lab = payu.laboratory.Laboratory(lab_path=str(labdir)) + expt = payu.experiment.Experiment(lab, reproduce=False) + model = expt.models[0] + + prior_config_files = model.config_files[:] + + # Function to test + model.add_config_files() + + # Check files are added to config_files + added_files = set(model.config_files).difference(prior_config_files) + assert added_files == set(expected_files_added) + + # Check the extra files are moved to model's work path + ctrl_path_files = os.listdir(model.work_path) + for file in expected_files_added: + assert file in ctrl_path_files