Skip to content

Commit

Permalink
UPSTREAM: <carry>: Added support for TLS to MLMD GRPC Server
Browse files Browse the repository at this point in the history
Signed-off-by: hbelmiro <[email protected]>
  • Loading branch information
hbelmiro committed Aug 30, 2024
1 parent a8fbbd2 commit a6f385c
Show file tree
Hide file tree
Showing 16 changed files with 78 additions and 22 deletions.
3 changes: 2 additions & 1 deletion backend/src/apiserver/client_manager/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ func (c *ClientManager) init() {

c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams)

newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort())
newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMLPipelineTLSEnabled())

if err != nil {
glog.Fatalf("Failed to create metadata client. Error: %v", err)
}
Expand Down
5 changes: 5 additions & 0 deletions backend/src/apiserver/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
MetadataGrpcServiceServiceHost string = "METADATA_GRPC_SERVICE_SERVICE_HOST"
MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT"
SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS"
MLPipelineTLSEnabled string = "ML_PIPELINE_TLS_ENABLED"
)

func IsPipelineVersionUpdatedByDefault() bool {
Expand Down Expand Up @@ -142,3 +143,7 @@ func GetMetadataGrpcServiceServicePort() string {
// GetSignedURLExpiryTimeSeconds returns the integer configured under the
// SignedURLExpiryTimeSeconds key, as resolved by GetIntConfigWithDefault with
// DefaultSignedURLExpiryTimeSeconds supplied as the default.
func GetSignedURLExpiryTimeSeconds() int {
	expirySeconds := GetIntConfigWithDefault(SignedURLExpiryTimeSeconds, DefaultSignedURLExpiryTimeSeconds)
	return expirySeconds
}

// GetMLPipelineTLSEnabled returns the boolean configured under the
// MLPipelineTLSEnabled key, as resolved by GetBoolFromStringWithDefault with
// DefaultMLPipelineTLSEnabled supplied as the default.
func GetMLPipelineTLSEnabled() bool {
	tlsEnabled := GetBoolFromStringWithDefault(MLPipelineTLSEnabled, DefaultMLPipelineTLSEnabled)
	return tlsEnabled
}
2 changes: 2 additions & 0 deletions backend/src/apiserver/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ const (

// DefaultSignedURLExpiryTimeSeconds is the default handed to
// GetIntConfigWithDefault by GetSignedURLExpiryTimeSeconds for the
// SIGNED_URL_EXPIRY_TIME_SECONDS setting.
const DefaultSignedURLExpiryTimeSeconds = 15

// DefaultMLPipelineTLSEnabled is the default handed to
// GetBoolFromStringWithDefault by GetMLPipelineTLSEnabled for the
// ML_PIPELINE_TLS_ENABLED setting.
const DefaultMLPipelineTLSEnabled = true

const (
MaxFileNameLength = 100
MaxFileLength = 32 << 20 // 32Mb
Expand Down
10 changes: 9 additions & 1 deletion backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ func drive() (err error) {
DAGExecutionID: *dagExecutionID,
IterationIndex: *iterationIndex,
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
MLMDServerAddress: *mlmdServerAddress,
MLMDServerPort: *mlmdServerPort,
}
var execution *driver.Execution
var driverErr error
Expand Down Expand Up @@ -292,5 +294,11 @@ func newMlmdClient() (*metadata.Client, error) {
mlmdConfig.Address = *mlmdServerAddress
mlmdConfig.Port = *mlmdServerPort
}
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port)

mlPipelineServiceTLSEnabled, err := strconv.ParseBool(*mlPipelineServiceTLSEnabledStr)
if err != nil {
return nil, err
}

return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, mlPipelineServiceTLSEnabled)
}
3 changes: 3 additions & 0 deletions backend/src/v2/compiler/argocompiler/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package argocompiler
import (
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"github.com/kubeflow/pipelines/backend/src/v2/component"
k8score "k8s.io/api/core/v1"
"os"
Expand Down Expand Up @@ -163,6 +164,8 @@ func (c *workflowCompiler) addContainerDriverTemplate() string {
"--condition_path", outputPath(paramCondition),
"--kubernetes_config", inputValue(paramKubernetesConfig),
"--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
},
Resources: driverResources,
},
Expand Down
3 changes: 3 additions & 0 deletions backend/src/v2/compiler/argocompiler/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package argocompiler

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -443,6 +444,8 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
"--iteration_count_path", outputPath(paramIterationCount),
"--condition_path", outputPath(paramCondition),
"--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
},
Resources: driverResources,
},
Expand Down
7 changes: 3 additions & 4 deletions backend/src/v2/compiler/argocompiler/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package argocompiler

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
Expand Down Expand Up @@ -76,10 +77,8 @@ func (c *workflowCompiler) addImporterTemplate() string {
fmt.Sprintf("$(%s)", component.EnvPodName),
"--pod_uid",
fmt.Sprintf("$(%s)", component.EnvPodUID),
"--mlmd_server_address",
fmt.Sprintf("$(%s)", component.EnvMetadataHost),
"--mlmd_server_port",
fmt.Sprintf("$(%s)", component.EnvMetadataPort),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
}
importerTemplate := &wfapi.Template{
Name: name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ spec:
- '{{inputs.parameters.kubernetes-config}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -312,6 +316,10 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ spec:
- '{{inputs.parameters.kubernetes-config}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -242,6 +246,10 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
env:
- name: ML_PIPELINE_SERVICE_HOST
value: ml-pipeline.kubeflow.svc.cluster.local
Expand Down
8 changes: 6 additions & 2 deletions backend/src/v2/compiler/argocompiler/testdata/importer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ spec:
- --pod_uid
- $(KFP_POD_UID)
- --mlmd_server_address
- $(METADATA_GRPC_SERVICE_HOST)
- "metadata-grpc-service"
- --mlmd_server_port
- $(METADATA_GRPC_SERVICE_PORT)
- "8080"
command:
- launcher-v2
env:
Expand Down Expand Up @@ -120,6 +120,10 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
2 changes: 1 addition & 1 deletion backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func NewImporterLauncher(ctx context.Context, componentSpecJSON, importerSpecJSO
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort)
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MLPipelineTLSEnabled)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort)
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLPipelineTLSEnabled)
if err != nil {
return nil, err
}
Expand Down
14 changes: 9 additions & 5 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ type Options struct {

// MLPipelineTLSEnabled is set to true if the ML Pipeline server is serving over TLS.
MLPipelineTLSEnabled bool

MLMDServerAddress string

MLMDServerPort string
}

// Identifying information used for error messages
Expand Down Expand Up @@ -339,7 +343,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
return execution, nil
}

podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled)
podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort)
if err != nil {
return execution, err
}
Expand Down Expand Up @@ -373,6 +377,8 @@ func initPodSpecPatch(
pipelineName string,
runID string,
mlPipelineTLSEnabled bool,
mlmdServerAddress string,
mlmdServerPort string,
) (*k8score.PodSpec, error) {
executorInputJSON, err := protojson.Marshal(executorInput)
if err != nil {
Expand Down Expand Up @@ -407,10 +413,8 @@ func initPodSpecPatch(
fmt.Sprintf("$(%s)", component.EnvPodName),
"--pod_uid",
fmt.Sprintf("$(%s)", component.EnvPodUID),
"--mlmd_server_address",
fmt.Sprintf("$(%s)", component.EnvMetadataHost),
"--mlmd_server_port",
fmt.Sprintf("$(%s)", component.EnvMetadataPort),
"--mlmd_server_address", mlmdServerAddress,
"--mlmd_server_port", mlmdServerPort,
"--mlPipelineServiceTLSEnabled",
fmt.Sprintf("%v", mlPipelineTLSEnabled),
"--", // separater before user command and args
Expand Down
4 changes: 2 additions & 2 deletions backend/src/v2/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port")
if tt.wantErr {
assert.Nil(t, podSpec)
assert.NotNil(t, err)
Expand Down Expand Up @@ -403,7 +403,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port")
assert.Nil(t, err)
assert.NotEmpty(t, podSpec)
podSpecString, err := json.Marshal(podSpec)
Expand Down
13 changes: 12 additions & 1 deletion backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ package metadata

import (
"context"
"crypto/tls"
"errors"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -105,14 +108,22 @@ type Client struct {
}

// NewClient creates a Client given the MLMD server address and port.
func NewClient(serverAddress, serverPort string) (*Client, error) {
func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, error) {
opts := []grpc_retry.CallOption{
grpc_retry.WithMax(mlmdClientSideMaxRetries),
grpc_retry.WithBackoff(grpc_retry.BackoffExponentialWithJitter(300*time.Millisecond, 0.20)),
grpc_retry.WithCodes(codes.Aborted),
}

creds := insecure.NewCredentials()
if tlsEnabled {
config := &tls.Config{}
creds = credentials.NewTLS(config)
}

conn, err := grpc.Dial(fmt.Sprintf("%s:%s", serverAddress, serverPort),
grpc.WithInsecure(),
grpc.WithTransportCredentials(creds),
grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(opts...)),
grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)),
)
Expand Down
8 changes: 4 additions & 4 deletions backend/src/v2/metadata/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func Test_GetPipeline(t *testing.T) {
runUuid, err := uuid.NewRandom()
fatalIf(err)
runId := runUuid.String()
client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort)
client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false)
fatalIf(err)
mlmdClient, err := NewTestMlmdClient()
fatalIf(err)
Expand Down Expand Up @@ -135,7 +135,7 @@ func Test_GetPipeline_Twice(t *testing.T) {
runUuid, err := uuid.NewRandom()
fatalIf(err)
runId := runUuid.String()
client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort)
client, err := metadata.NewClient(testMlmdServerAddress, testMlmdServerPort, false)
fatalIf(err)

pipeline, err := client.GetPipeline(ctx, "get-pipeline-test", runId, namespace, runResource, pipelineRoot, "")
Expand Down Expand Up @@ -177,7 +177,7 @@ func Test_GetPipelineConcurrently(t *testing.T) {
t.Skip("Temporarily disable the test that requires cluster connection.")

// This test depends on a MLMD grpc server running at localhost:8080.
client, err := metadata.NewClient("localhost", "8080")
client, err := metadata.NewClient("localhost", "8080", false)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -281,7 +281,7 @@ func Test_DAG(t *testing.T) {

func newLocalClientOrFatal(t *testing.T) *metadata.Client {
t.Helper()
client, err := metadata.NewClient("localhost", "8080")
client, err := metadata.NewClient("localhost", "8080", false)
if err != nil {
t.Fatalf("metadata.NewClient failed: %v", err)
}
Expand Down

0 comments on commit a6f385c

Please sign in to comment.