diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 9f9830fc17..4737de413a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -10,11 +10,11 @@ import ( sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - // NOTE: this import also use things inside google.golang structpb one + // NOTE: this import also use things inside google.golang structpb one // structpb "github.com/golang/protobuf/ptypes/struct" - "google.golang.org/protobuf/types/known/structpb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -367,8 +367,6 @@ func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string Target: &core.TaskTemplate_Container{ Container: &core.Container{ Image: testImage, - Args: testArgs, - Env: dummyEnvVars, }, }, Config: map[string]string{ @@ -974,13 +972,6 @@ func TestGetPropertiesSpark(t *testing.T) { } func TestBuildResourceCustomK8SPod(t *testing.T) { - // TODO: edit below tests for custom driver and executor - // the TestBuildResourcePodTemplate test whether the custom Toleration is displayed - - // create dummy driver and executor pod - // dummy sparkJob that takes in dummy driver and executor pod - // see whether the driver and worker podSpec is what we set - // what properties to test defaultConfig := defaultPluginConfig() assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) @@ -1007,9 +998,17 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { driverK8SPod := &core.K8SPod{ PodSpec: transformStructToStructPB(t, driverPodSpec), + Metadata: &core.K8SObjectMetadata{ + Annotations: map[string]string{"annotation-driver": "val-driver"}, + Labels: map[string]string{"label-driver": "val-driver"}, + }, } executorK8SPod := &core.K8SPod{ PodSpec: transformStructToStructPB(t, executorPodSpec), + Metadata: &core.K8SObjectMetadata{ + Annotations: map[string]string{"annotation-executor": "val-executor"}, + Labels: map[string]string{"label-executor": "val-executor"}, + }, } // put the driver/executor podspec (add custom tolerations) to below function taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod) @@ -1042,19 +1041,27 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) // Driver - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) + assert.Equal(t, utils.UnionMaps( + defaultConfig.DefaultAnnotations, map[string]string{ + "annotation-1": "val1", + "annotation-driver": "val-driver", + }, + ), sparkApp.Spec.Driver.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{ + "label-1": "val1", + "label-driver": "val-driver", + }), sparkApp.Spec.Driver.Labels) assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) - // assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) - // assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET")) - // assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env)) + assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env)) assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) - // assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) - // assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig) - // assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork) - // assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) assert.Equal(t, []corev1.Toleration{ defaultConfig.DefaultTolerations[0], driverExtraToleration, @@ -1080,8 +1087,14 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) // // Executor - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations) - assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{ + "annotation-1": "val1", + "annotation-executor": "val-executor", + }), sparkApp.Spec.Executor.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{ + "label-1": "val1", + "label-executor": "val-executor", + }), sparkApp.Spec.Executor.Labels) assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET")) @@ -1120,7 +1133,6 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) } - func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct { data, err := json.Marshal(obj) assert.Nil(t, err)