Skip to content

Commit

Permalink
Merge pull request #30 from DARPA-ASKEM/kbirk/taskrunner-progress
Browse files Browse the repository at this point in the history
Add progress stuff
  • Loading branch information
kbirk committed Apr 19, 2024
2 parents 62e7461 + 8db4c8e commit 6f52ee4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
41 changes: 41 additions & 0 deletions core/taskrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ def __init__(self, description: str):
required=False,
help="The name of the output pipe",
)
parser.add_argument(
"--progress_pipe",
type=str,
required=False,
help="The name of the progress pipe",
)
parser.add_argument(
"--self_destruct_timeout_seconds",
type=int,
Expand All @@ -64,6 +70,8 @@ def __init__(self, description: str):
self.input = args.input
self.input_pipe = args.input_pipe
self.output_pipe = args.output_pipe
self.progress_pipe = args.progress_pipe
self.has_written_output = False

if self.input is None and self.input_pipe is None:
raise ValueError("Either `input` or `input_pipe` must be specified")
Expand Down Expand Up @@ -101,6 +109,30 @@ def read_input() -> dict:
except concurrent.futures.TimeoutError:
raise TimeoutError("Reading from input pipe timed out")

def write_progress_with_timeout(self, progress: dict, timeout_seconds: int):
def write_progress(progress_pipe: str, progress: dict):
bs = json.dumps(progress, separators=(",", ":")).encode()
with open(progress_pipe, "wb") as f_out:
f_out.write(bs)
return

# if no progress pipe is specified, just print the progress to stdout
if self.progress_pipe is None:
self.log("Writing progress to stdout")
print(json.dumps(progress))
return

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(write_progress, self.progress_pipe, progress)
try:
return future.result(timeout=timeout_seconds)
except concurrent.futures.TimeoutError:
print(
"Writing to progress pipe {} timed out".format(self.progress_pipe),
flush=True,
)
raise TimeoutError("Writing to output pipe timed out")

def write_output_with_timeout(self, output: dict, timeout_seconds: int = 30):
def write_output(output: dict):
self.log("Writing output to output pipe")
Expand All @@ -109,12 +141,21 @@ def write_output(output: dict):
f_out.write(bs)
return

# output should only be written once
if self.has_written_output:
raise ValueError("Output has already been written")

self.has_written_output = True

# if no output pipe is specified, just print the output to stdout
if self.output_pipe is None:
self.log("Writing output to stdout")
print(json.dumps(output))
return

# signal to the taskrunner that it should stop consuming progress
self.write_progress_with_timeout(self, {"done": True}, timeout_seconds)

# otherwise use the output pipe
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(write_output, output)
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
"gollm:configure_model=tasks.configure_model:main",
"gollm:model_card=tasks.model_card:main",
"gollm:embedding=tasks.embedding:main",
"gollm:compare_models=tasks.compare_models:main",
"gollm:dataset_configure=tasks.dataset_configure:main",
"gollm:compare_models=tasks.compare_models:main",
"gollm:dataset_configure=tasks.dataset_configure:main",
],
},
python_requires=">=3.8",
Expand Down

0 comments on commit 6f52ee4

Please sign in to comment.