Skip to content

Commit

Permalink
Refactor the list_archive_dirs function out of experiment and into fsops
Browse files (browse the repository at this point in the history)
  • Loading branch information
Jo Basevi committed Dec 21, 2023
1 parent 47c36f9 commit 6e59c50
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 43 deletions.
35 changes: 11 additions & 24 deletions payu/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ruamel.yaml import YAML, CommentedMap
import git

from payu.fsops import read_config, DEFAULT_CONFIG_FNAME
from payu.fsops import read_config, DEFAULT_CONFIG_FNAME, list_archive_dirs
from payu.laboratory import Laboratory
from payu.metadata import Metadata, UUID_FIELD, METADATA_FILENAME
from payu.git_utils import GitRepository, git_clone
Expand All @@ -36,25 +36,10 @@
Where BRANCH_NAME is the name of the branch"""


def archive_contains_restarts(archive_path: Path) -> bool:
    """Return True if there's pre-existing restarts in archive.

    An entry counts as a restart when its name is ``restart`` followed by
    three or more digits and, after resolving symlinks, it is a directory.

    Parameters
    ----------
    archive_path : Path
        Path of the archive to inspect.

    Returns
    -------
    bool
        True if at least one restart directory is found; False if there is
        none, or if ``archive_path`` is missing or is not a directory.
    """
    pattern = re.compile(r"^restart[0-9][0-9][0-9]+$")

    # is_dir() rather than exists(): a path that exists but is a regular
    # file would otherwise make iterdir() raise NotADirectoryError.
    if not archive_path.is_dir():
        return False

    # Test the (cheap) name first; only matching entries are resolved so
    # that a symlink pointing at a restart directory is also accepted.
    return any(
        pattern.match(path.name) and path.resolve().is_dir()
        for path in archive_path.iterdir()
    )


def check_restart(restart_path: Optional[Path],
archive_path: Path) -> Optional[Path]:
"""Checks for valid prior restart path. Returns resolved restart path
if valid, otherwise returns None"""
if restart_path is None:
return

# Check for valid path
if not restart_path.exists():
Expand All @@ -66,11 +51,12 @@ def check_restart(restart_path: Optional[Path],
restart_path = restart_path.resolve()

# Check for pre-existing restarts in archive
if archive_contains_restarts(archive_path):
warnings.warn((
f"Pre-existing restarts found in archive: {archive_path}."
f"Skipping adding 'restart: {restart_path}' to config file"))
return
if archive_path.exists():
if len(list_archive_dirs(archive_path, dir_type="restart")) > 0:
warnings.warn((
f"Pre-existing restarts found in archive: {archive_path}."
f"Skipping adding 'restart: {restart_path}' to config file"))
return

return restart_path

Expand Down Expand Up @@ -171,8 +157,10 @@ def checkout_branch(branch_name: str,
is_new_experiment=is_new_experiment)

# Gets valid prior restart path
prior_restart_path = check_restart(restart_path=restart_path,
archive_path=metadata.archive_path)
prior_restart_path = None
if restart_path:
prior_restart_path = check_restart(restart_path=restart_path,
archive_path=metadata.archive_path)

# Create/update and commit metadata file
metadata.write_metadata(set_template_values=True,
Expand Down Expand Up @@ -288,7 +276,6 @@ def clone(repository: str,

def get_branch_metadata(branch: git.Head) -> Optional[CommentedMap]:
"""Return dictionary of branch metadata if it exists, None otherwise"""
# Note: Blobs are files in the commit tree
for blob in branch.commit.tree.blobs:
if blob.name == METADATA_FILENAME:
# Read file contents
Expand Down
18 changes: 5 additions & 13 deletions payu/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# Local
from payu import envmod
from payu.fsops import mkdir_p, make_symlink, read_config, movetree
from payu.fsops import list_archive_dirs
from payu.schedulers.pbs import get_job_info, pbs_env_init, get_job_id
from payu.models import index as model_index
import payu.profilers
Expand Down Expand Up @@ -198,7 +199,8 @@ def max_output_index(self, output_type="output"):
"""Given a output directory type (output or restart),
return the maximum index of output directories found"""
try:
output_dirs = self.list_output_dirs(output_type)
output_dirs = list_archive_dirs(archive_path=self.archive_path,
dir_type=output_type)
except EnvironmentError as exc:
if exc.errno == errno.ENOENT:
output_dirs = None
Expand All @@ -208,17 +210,6 @@ def max_output_index(self, output_type="output"):
if output_dirs and len(output_dirs):
return int(output_dirs[-1].lstrip(output_type))

def list_output_dirs(self, output_type="output", full_path=False):
    """Return a sorted list of restart or output directories in archive.

    Entries of ``self.archive_path`` are selected purely by name:
    ``output_type`` followed by three or more digits. They are sorted by
    the integer value of that numeric suffix.

    Parameters
    ----------
    output_type : str
        Literal name prefix to look for (e.g. "output" or "restart").
    full_path : bool
        If True, return each name joined onto ``self.archive_path``;
        otherwise return bare entry names.

    Returns
    -------
    list of str
        Matching names (or paths) in ascending numeric order.
    """
    # Escape the prefix so it is matched literally inside the pattern
    naming_pattern = re.compile(
        fr"^{re.escape(output_type)}[0-9][0-9][0-9]+$"
    )
    dirs = [d for d in os.listdir(self.archive_path)
            if naming_pattern.match(d)]

    # Slice the prefix off: str.lstrip() removes a *set of characters*,
    # which would also eat leading digits of the index whenever the
    # prefix itself contains digits.
    prefix_len = len(output_type)
    dirs.sort(key=lambda d: int(d[prefix_len:]))

    if full_path:
        dirs = [os.path.join(self.archive_path, d) for d in dirs]
    return dirs

def set_stacksize(self, stacksize):

if stacksize == 'unlimited':
Expand Down Expand Up @@ -972,7 +963,8 @@ def get_restarts_to_prune(self,
return []

# List all restart directories in archive
restarts = self.list_output_dirs(output_type='restart')
restarts = list_archive_dirs(archive_path=self.archive_path,
dir_type='restart')

# TODO: Previous logic was to prune all restarts if self.repeat_run
# Still need to figure out what should happen in this case
Expand Down
15 changes: 14 additions & 1 deletion payu/fsops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Standard library
import errno
import os
import re
import shutil
import sys
import shlex
Expand Down Expand Up @@ -207,4 +208,16 @@ def required_libs(bin_path):
except:
print("payu: error running ldd command on exe path: ", bin_path)
return {}
return parse_ldd_output(ldd_out)
return parse_ldd_output(ldd_out)


def list_archive_dirs(archive_path, dir_type="output", full_path=False):
    """Return a sorted list of restart or output directories in archive.

    Entries of ``archive_path`` are selected purely by name: ``dir_type``
    followed by three or more digits. They are sorted by the integer value
    of that numeric suffix. No check is made that an entry is actually a
    directory.

    Parameters
    ----------
    archive_path : str or path-like
        Archive directory to list.
    dir_type : str
        Literal name prefix to look for (e.g. "output" or "restart").
    full_path : bool
        If True, return each name joined onto ``archive_path``; otherwise
        return bare entry names.

    Returns
    -------
    list of str
        Matching names (or paths) in ascending numeric order.

    Raises
    ------
    OSError
        Propagated from ``os.listdir`` (e.g. when ``archive_path`` does
        not exist).
    """
    # Escape the prefix so it is matched literally inside the pattern
    naming_pattern = re.compile(fr"^{re.escape(dir_type)}[0-9][0-9][0-9]+$")
    dirs = [d for d in os.listdir(archive_path)
            if naming_pattern.match(d)]

    # Slice the prefix off: str.lstrip() removes a *set of characters*,
    # which would also eat leading digits of the index whenever the
    # prefix itself contains digits.
    prefix_len = len(dir_type)
    dirs.sort(key=lambda d: int(d[prefix_len:]))

    if full_path:
        dirs = [os.path.join(archive_path, d) for d in dirs]
    return dirs
12 changes: 7 additions & 5 deletions payu/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


# Local
from payu.fsops import mkdir_p
from payu.fsops import mkdir_p, list_archive_dirs
from payu.metadata import METADATA_FILENAME


Expand Down Expand Up @@ -49,8 +49,9 @@ def __init__(self, expt):
def add_outputs_to_sync(self):
"""Add paths of outputs in archive to sync. The last output is
protected"""
outputs = self.expt.list_output_dirs(output_type='output',
full_path=True)
outputs = list_archive_dirs(archive_path=self.expt.archive_path,
dir_type='output',
full_path=True)
if len(outputs) > 0:
last_output = outputs.pop()
if not self.ignore_last:
Expand All @@ -70,8 +71,9 @@ def add_restarts_to_sync(self):
return

# Get sorted list of restarts in archive
restarts = self.expt.list_output_dirs(output_type='restart',
full_path=True)
restarts = list_archive_dirs(archive_path=self.expt.archive_path,
dir_type='restart',
full_path=True)
if restarts == []:
return

Expand Down

0 comments on commit 6e59c50

Please sign in to comment.