Skip to content

Commit

Permalink
Move conditional expand logic to base class
Browse files Browse the repository at this point in the history
Also add a gpu-per-node entry
  • Loading branch information
linsword13 committed Dec 17, 2024
1 parent 716811c commit bf8ae08
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
22 changes: 22 additions & 0 deletions lib/ramble/ramble/workflow_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ramble.util.naming import NS_SEPARATOR
import ramble.util.class_attributes
import ramble.util.directives
from ramble.expander import ExpanderError


class WorkflowManagerBase(metaclass=WorkflowManagerMeta):
Expand Down Expand Up @@ -48,6 +49,27 @@ def get_status(self, workspace):
"""Return status of a given job"""
return None

def conditional_expand(self, templates):
    """Expand each template, dropping entries that cannot be fully expanded.

    Args:
        templates: A list of template strings to expand. A template whose
            variables are not all defined is skipped entirely.

    Returns:
        A (potentially empty) list of the non-empty expanded strings.
    """
    expand = self.app_inst.expander.expand_var
    results = []
    for template in templates:
        try:
            value = expand(template, allow_passthrough=False)
        except ExpanderError:
            # A variable referenced by this template is undefined: skip it
            continue
        if value:
            results.append(value)
    return results

def copy(self):
"""Deep copy a workflow manager instance"""
new_copy = type(self)(self._file_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ def __init__(self, file_path):

self.runner = SlurmRunner()

workflow_manager_variable(
name="partition",
default="",
description="Name of the slurm partition for job submission",
)

workflow_manager_variable(
name="job_name",
default="{application_name}_{workload_name}_{experiment_name}",
Expand Down Expand Up @@ -96,31 +90,21 @@ def _slurm_execute_script(self):
expander = self.app_inst.expander
# Adding pre-defined sets of headers
pragmas = [
("#SBATCH -N {}", "n_nodes"),
("#SBATCH -p {}", "partition"),
("#SBATCH --ntasks-per-node {}", "processes_per_node"),
("#SBATCH -J {}", "job_name"),
("#SBATCH -N {n_nodes}"),
("#SBATCH -p {partition}"),
("#SBATCH --ntasks-per-node {processes_per_node}"),
("#SBATCH -J {job_name}"),
("#SBATCH --gpus-per-node {gpus_per_node}"),
]
for tpl, var in pragmas:
try:
val = expander.expand_var_name(var, allow_passthrough=False)
except ExpanderError:
# Skip a particular header if any of the vars are not defined
continue
if val:
headers.append(tpl.format(val))
# Adding extra arbitrary headers
try:
extra_sbatch_headers_raw = expander.expand_var_name(
"extra_sbatch_headers", allow_passthrough=False
)
extra_sbatch_headers = extra_sbatch_headers_raw.strip().split("\n")
extra_headers = [
expander.expand_var(h) for h in extra_sbatch_headers
]
headers = headers + extra_headers
pragmas = pragmas + extra_sbatch_headers
except ExpanderError:
pass
headers = headers + self.conditional_expand(pragmas)
header_str = "\n".join(headers)
content = rf"""
{header_str}
Expand Down

0 comments on commit bf8ae08

Please sign in to comment.