diff --git a/backend/src/apiserver/client_manager/client_manager.go b/backend/src/apiserver/client_manager/client_manager.go index e74310f3e960..1de94fa4c13f 100644 --- a/backend/src/apiserver/client_manager/client_manager.go +++ b/backend/src/apiserver/client_manager/client_manager.go @@ -208,7 +208,8 @@ func (c *ClientManager) init() { c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams) - newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort()) + newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMLPipelineTLSEnabled()) + if err != nil { glog.Fatalf("Failed to create metadata client. Error: %v", err) } diff --git a/backend/src/apiserver/common/config.go b/backend/src/apiserver/common/config.go index c0763ed07778..ba258d6578e0 100644 --- a/backend/src/apiserver/common/config.go +++ b/backend/src/apiserver/common/config.go @@ -35,6 +35,7 @@ const ( MetadataGrpcServiceServiceHost string = "METADATA_GRPC_SERVICE_SERVICE_HOST" MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT" SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS" + MLPipelineTLSEnabled string = "ML_PIPELINE_TLS_ENABLED" ) func IsPipelineVersionUpdatedByDefault() bool { @@ -142,3 +143,7 @@ func GetMetadataGrpcServiceServicePort() string { func GetSignedURLExpiryTimeSeconds() int { return GetIntConfigWithDefault(SignedURLExpiryTimeSeconds, DefaultSignedURLExpiryTimeSeconds) } + +func GetMLPipelineTLSEnabled() bool { + return GetBoolFromStringWithDefault(MLPipelineTLSEnabled, DefaultMLPipelineTLSEnabled) +} diff --git a/backend/src/apiserver/common/const.go b/backend/src/apiserver/common/const.go index 168675e294c0..59296dc63094 100644 --- a/backend/src/apiserver/common/const.go +++ b/backend/src/apiserver/common/const.go @@ -68,6 +68,8 @@ const ( const DefaultSignedURLExpiryTimeSeconds = 15 +const DefaultMLPipelineTLSEnabled = true + const ( MaxFileNameLength = 100 MaxFileLength = 32 << 20 // 32Mb diff --git a/backend/src/v2/cmd/driver/main.go b/backend/src/v2/cmd/driver/main.go index 98127c284460..1a78869d3010 100644 --- a/backend/src/v2/cmd/driver/main.go +++ b/backend/src/v2/cmd/driver/main.go @@ -167,6 +167,8 @@ func drive() (err error) { DAGExecutionID: *dagExecutionID, IterationIndex: *iterationIndex, MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled, + MLMDServerAddress: *mlmdServerAddress, + MLMDServerPort: *mlmdServerPort, } var execution *driver.Execution var driverErr error @@ -292,5 +294,11 @@ func newMlmdClient() (*metadata.Client, error) { mlmdConfig.Address = *mlmdServerAddress mlmdConfig.Port = *mlmdServerPort } - return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port) + + mlPipelineServiceTLSEnabled, err := strconv.ParseBool(*mlPipelineServiceTLSEnabledStr) + if err != nil { + return nil, err + } + + return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, mlPipelineServiceTLSEnabled) } diff --git a/backend/src/v2/compiler/argocompiler/container.go b/backend/src/v2/compiler/argocompiler/container.go index 7b12ca174d1b..636e3bbd1b26 100644 --- a/backend/src/v2/compiler/argocompiler/container.go +++ b/backend/src/v2/compiler/argocompiler/container.go @@ -17,6 +17,7 @@ package argocompiler import ( wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/v2/component" k8score "k8s.io/api/core/v1" "os" @@ -163,6 +164,8 @@ func (c *workflowCompiler) addContainerDriverTemplate() string { "--condition_path", outputPath(paramCondition), "--kubernetes_config", inputValue(paramKubernetesConfig), "--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled), + "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), + "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), }, Resources: driverResources, }, diff --git a/backend/src/v2/compiler/argocompiler/dag.go b/backend/src/v2/compiler/argocompiler/dag.go index 36a239667e3b..d42246577bda 100644 --- a/backend/src/v2/compiler/argocompiler/dag.go +++ b/backend/src/v2/compiler/argocompiler/dag.go @@ -15,6 +15,7 @@ package argocompiler import ( "fmt" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" "sort" "strconv" "strings" @@ -443,6 +444,8 @@ func (c *workflowCompiler) addDAGDriverTemplate() string { "--iteration_count_path", outputPath(paramIterationCount), "--condition_path", outputPath(paramCondition), "--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled), + "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), + "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), }, Resources: driverResources, }, diff --git a/backend/src/v2/compiler/argocompiler/importer.go b/backend/src/v2/compiler/argocompiler/importer.go index 83ac6453b64e..e84c2d673b1a 100644 --- a/backend/src/v2/compiler/argocompiler/importer.go +++ b/backend/src/v2/compiler/argocompiler/importer.go @@ -16,6 +16,7 @@ package argocompiler import ( "fmt" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" @@ -76,10 +77,8 @@ func (c *workflowCompiler) addImporterTemplate() string { fmt.Sprintf("$(%s)", component.EnvPodName), "--pod_uid", fmt.Sprintf("$(%s)", component.EnvPodUID), - "--mlmd_server_address", - fmt.Sprintf("$(%s)", component.EnvMetadataHost), - "--mlmd_server_port", - fmt.Sprintf("$(%s)", component.EnvMetadataPort), + "--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(), + "--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(), } importerTemplate := &wfapi.Template{ Name: name, diff --git a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml index d4cd73085dfc..bf1abffcebc6 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml @@ -67,6 +67,10 @@ spec: - '{{inputs.parameters.kubernetes-config}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" command: - driver image: gcr.io/ml-pipeline/kfp-driver @@ -312,6 +316,10 @@ spec: - '{{outputs.parameters.condition.path}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml index e285ad07188f..985f03e14c20 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml @@ -50,6 +50,10 @@ spec: - '{{inputs.parameters.kubernetes-config}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" command: - driver image: gcr.io/ml-pipeline/kfp-driver @@ -242,6 +246,10 @@ spec: - '{{outputs.parameters.condition.path}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" env: - name: ML_PIPELINE_SERVICE_HOST value: ml-pipeline.kubeflow.svc.cluster.local diff --git a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml index 0e2d30a12b23..7c81f24caf96 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml @@ -38,9 +38,9 @@ spec: - --pod_uid - $(KFP_POD_UID) - --mlmd_server_address - - $(METADATA_GRPC_SERVICE_HOST) + - "metadata-grpc-service" - --mlmd_server_port - - $(METADATA_GRPC_SERVICE_PORT) + - "8080" command: - launcher-v2 env: @@ -120,6 +120,10 @@ spec: - '{{outputs.parameters.condition.path}}' - "--mlPipelineServiceTLSEnabled" - "false" + - "--mlmd_server_address" + - "metadata-grpc-service" + - "--mlmd_server_port" + - "8080" command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/component/importer_launcher.go b/backend/src/v2/component/importer_launcher.go index e6dae29d639c..c5506f0c3163 100644 --- a/backend/src/v2/component/importer_launcher.go +++ b/backend/src/v2/component/importer_launcher.go @@ -88,7 +88,7 @@ func NewImporterLauncher(ctx context.Context, componentSpecJSON, importerSpecJSO if err != nil { return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err) } - metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort) + metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MLPipelineTLSEnabled) if err != nil { return nil, err } diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index b7682d5a4e55..af25ee2c1078 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -110,7 +110,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co if err != nil { return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err) } - metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort) + metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLPipelineTLSEnabled) if err != nil { return nil, err } diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index b2f0e15c6a0f..7645bd88c4a7 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -77,6 +77,10 @@ type Options struct { // set to true if ml pipeline server is serving over tls MLPipelineTLSEnabled bool + + MLMDServerAddress string + + MLMDServerPort string } // Identifying information used for error messages @@ -339,7 +343,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return execution, nil } - podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled) + podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort) if err != nil { return execution, err } @@ -373,6 +377,8 @@ func initPodSpecPatch( pipelineName string, runID string, mlPipelineTLSEnabled bool, + mlmdServerAddress string, + mlmdServerPort string, ) (*k8score.PodSpec, error) { executorInputJSON, err := protojson.Marshal(executorInput) if err != nil { @@ -407,10 +413,8 @@ func initPodSpecPatch( fmt.Sprintf("$(%s)", component.EnvPodName), "--pod_uid", fmt.Sprintf("$(%s)", component.EnvPodUID), - "--mlmd_server_address", - fmt.Sprintf("$(%s)", component.EnvMetadataHost), - "--mlmd_server_port", - fmt.Sprintf("$(%s)", component.EnvMetadataPort), + "--mlmd_server_address", mlmdServerAddress, + "--mlmd_server_port", mlmdServerPort, "--mlPipelineServiceTLSEnabled", fmt.Sprintf("%v", mlPipelineTLSEnabled), "--", // separater before user command and args diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index 34ed4d13bb30..2d4c606eee79 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -241,7 +241,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port") if tt.wantErr { assert.Nil(t, podSpec) assert.NotNil(t, err) @@ -403,7 +403,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port") assert.Nil(t, err) assert.NotEmpty(t, podSpec) podSpecString, err := json.Marshal(podSpec) diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index eaaae44896a4..bb86b6d8ba75 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -18,8 +18,11 @@ package metadata import ( "context" + "crypto/tls" "errors" "fmt" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "path" "strconv" "strings" @@ -105,14 +108,22 @@ type Client struct { } // NewClient creates a Client given the MLMD server address and port. -func NewClient(serverAddress, serverPort string) (*Client, error) { +func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, error) { opts := []grpc_retry.CallOption{ grpc_retry.WithMax(mlmdClientSideMaxRetries), grpc_retry.WithBackoff(grpc_retry.BackoffExponentialWithJitter(300*time.Millisecond, 0.20)), grpc_retry.WithCodes(codes.Aborted), } + + creds := insecure.NewCredentials() + if tlsEnabled { + config := &tls.Config{} + creds = credentials.NewTLS(config) + } + conn, err := grpc.Dial(fmt.Sprintf("%s:%s", serverAddress, serverPort), grpc.WithInsecure(), + grpc.WithTransportCredentials(creds), grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(opts...)), grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)), ) diff --git a/backend/src/v2/metadata/client_test.go b/backend/src/v2/metadata/client_test.go index ea3bf34dde15..d8e25efef3cc 100644 --- a/backend/src/v2/metadata/client_test.go +++ b/backend/src/v2/metadata/client_test.go @@ -84,7 +84,7 @@ func Test_GetPipeline(t *testing.T) { runUuid, err := uuid.NewRandom() fatalIf(err) runId := runUuid.String() - client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort) + client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false) fatalIf(err) mlmdClient, err := NewTestMlmdClient() fatalIf(err) @@ -135,7 +135,7 @@ func Test_GetPipeline_Twice(t *testing.T) { runUuid, err := uuid.NewRandom() fatalIf(err) runId := runUuid.String() - client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort) + client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false) fatalIf(err) pipeline, err := client.GetPipeline(ctx, "get-pipeline-test", runId, namespace, runResource, pipelineRoot, "") @@ -177,7 +177,7 @@ func Test_GetPipelineConcurrently(t *testing.T) { t.Skip("Temporarily disable the test that requires cluster connection.") // This test depends on a MLMD grpc server running at localhost:8080. - client, err := metadata.NewClient("localhost", "8080") + client, err := metadata.NewClient("localhost", "8080", false) if err != nil { t.Fatal(err) } @@ -281,7 +281,7 @@ func Test_DAG(t *testing.T) { func newLocalClientOrFatal(t *testing.T) *metadata.Client { t.Helper() - client, err := metadata.NewClient("localhost", "8080") + client, err := metadata.NewClient("localhost", "8080", false) if err != nil { t.Fatalf("metadata.NewClient failed: %v", err) }