diff --git a/backend/src/v2/cmd/driver/execution_paths.go b/backend/src/v2/cmd/driver/execution_paths.go new file mode 100644 index 00000000000..584d29065d5 --- /dev/null +++ b/backend/src/v2/cmd/driver/execution_paths.go @@ -0,0 +1,9 @@ +package main + +type ExecutionPaths struct { + ExecutionID string + IterationCount string + CachedDecision string + Condition string + PodSpecPatch string +} diff --git a/backend/src/v2/cmd/driver/main.go b/backend/src/v2/cmd/driver/main.go index 588d211521b..793ccfe1b80 100644 --- a/backend/src/v2/cmd/driver/main.go +++ b/backend/src/v2/cmd/driver/main.go @@ -37,6 +37,9 @@ import ( const ( driverTypeArg = "type" + ROOT_DAG = "ROOT_DAG" + DAG = "DAG" + CONTAINER = "CONTAINER" ) var ( @@ -160,12 +163,12 @@ func drive() (err error) { var execution *driver.Execution var driverErr error switch *driverType { - case "ROOT_DAG": + case ROOT_DAG: options.RuntimeConfig = runtimeConfig execution, driverErr = driver.RootDAG(ctx, options, client) - case "DAG": + case DAG: execution, driverErr = driver.DAG(ctx, options, client) - case "CONTAINER": + case CONTAINER: options.Container = containerSpec options.KubernetesExecutorConfig = k8sExecCfg execution, driverErr = driver.Container(ctx, options, client, cacheClient) @@ -183,35 +186,60 @@ func drive() (err error) { err = driverErr }() } + + executionPaths := &ExecutionPaths{ + ExecutionID: *executionIDPath, + IterationCount: *iterationCountPath, + CachedDecision: *cachedDecisionPath, + Condition: *conditionPath, + PodSpecPatch: *podSpecPatchPath} + + return handleExecution(execution, *driverType, executionPaths) +} + +func handleExecution(execution *driver.Execution, driverType string, executionPaths *ExecutionPaths) error { if execution.ID != 0 { glog.Infof("output execution.ID=%v", execution.ID) - if *executionIDPath != "" { - if err = writeFile(*executionIDPath, []byte(fmt.Sprint(execution.ID))); err != nil { + if executionPaths.ExecutionID != "" { + if err := writeFile(executionPaths.ExecutionID, []byte(fmt.Sprint(execution.ID))); err != nil { return fmt.Errorf("failed to write execution ID to file: %w", err) } } } if execution.IterationCount != nil { - if err = writeFile(*iterationCountPath, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil { + if err := writeFile(executionPaths.IterationCount, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil { return fmt.Errorf("failed to write iteration count to file: %w", err) } + } else { + if driverType == ROOT_DAG { + if err := writeFile(executionPaths.IterationCount, []byte("0")); err != nil { + return fmt.Errorf("failed to write iteration count to file: %w", err) + } + } } if execution.Cached != nil { - if err = writeFile(*cachedDecisionPath, []byte(strconv.FormatBool(*execution.Cached))); err != nil { + if err := writeFile(executionPaths.CachedDecision, []byte(strconv.FormatBool(*execution.Cached))); err != nil { return fmt.Errorf("failed to write cached decision to file: %w", err) } } if execution.Condition != nil { - if err = writeFile(*conditionPath, []byte(strconv.FormatBool(*execution.Condition))); err != nil { + if err := writeFile(executionPaths.Condition, []byte(strconv.FormatBool(*execution.Condition))); err != nil { return fmt.Errorf("failed to write condition to file: %w", err) } + } else { + // nil is a valid value for Condition + if driverType == ROOT_DAG || driverType == CONTAINER { + if err := writeFile(executionPaths.Condition, []byte("nil")); err != nil { + return fmt.Errorf("failed to write condition to file: %w", err) + } + } } if execution.PodSpecPatch != "" { glog.Infof("output podSpecPatch=\n%s\n", execution.PodSpecPatch) - if *podSpecPatchPath == "" { + if executionPaths.PodSpecPatch == "" { return fmt.Errorf("--pod_spec_patch_path is required for container executor drivers") } - if err = writeFile(*podSpecPatchPath, []byte(execution.PodSpecPatch)); err != nil { + if err := writeFile(executionPaths.PodSpecPatch, []byte(execution.PodSpecPatch)); err != nil { return fmt.Errorf("failed to write pod spec patch to file: %w", err) } } diff --git a/backend/src/v2/cmd/driver/main_test.go b/backend/src/v2/cmd/driver/main_test.go new file mode 100644 index 00000000000..abaea81a804 --- /dev/null +++ b/backend/src/v2/cmd/driver/main_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "github.com/kubeflow/pipelines/backend/src/v2/driver" + "os" + "testing" +) + +func Test_handleExecutionContainer(t *testing.T) { + execution := &driver.Execution{} + + executionPaths := &ExecutionPaths{ + Condition: "condition.txt", + } + + err := handleExecution(execution, CONTAINER, executionPaths) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + verifyFileContent(t, executionPaths.Condition, "nil") + + cleanup(t, executionPaths) +} + +func Test_handleExecutionRootDAG(t *testing.T) { + execution := &driver.Execution{} + + executionPaths := &ExecutionPaths{ + IterationCount: "iteration_count.txt", + Condition: "condition.txt", + } + + err := handleExecution(execution, ROOT_DAG, executionPaths) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + verifyFileContent(t, executionPaths.IterationCount, "0") + verifyFileContent(t, executionPaths.Condition, "nil") + + cleanup(t, executionPaths) +} + +func cleanup(t *testing.T, executionPaths *ExecutionPaths) { + removeIfExists(t, executionPaths.IterationCount) + removeIfExists(t, executionPaths.ExecutionID) + removeIfExists(t, executionPaths.Condition) + removeIfExists(t, executionPaths.PodSpecPatch) + removeIfExists(t, executionPaths.CachedDecision) +} + +func removeIfExists(t *testing.T, filePath string) { + _, err := os.Stat(filePath) + if err == nil { + err = os.Remove(filePath) + if err != nil { + t.Errorf("Unexpected error while removing the created file: %v", err) + } + } +} + +func verifyFileContent(t *testing.T, filePath string, expectedContent string) { + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + t.Errorf("Expected file %s to be created, but it doesn't exist", filePath) + } + + fileContent, err := os.ReadFile(filePath) + if err != nil { + t.Errorf("Failed to read file contents: %v", err) + } + + if string(fileContent) != expectedContent { + t.Errorf("Expected file fileContent to be %q, got %q", expectedContent, string(fileContent)) + } +}