Skip to content

Commit

Permalink
[flytepropeller] Add tests in v1alpha (#5896)
Browse files Browse the repository at this point in the history
* test: add more tests in flytepropeller

Signed-off-by: DenChenn <[email protected]>

* test: refactor previous json marshal test

Signed-off-by: DenChenn <[email protected]>

* refactor: fix wrong comment

Signed-off-by: DenChenn <[email protected]>

* refactor: lint

Signed-off-by: DenChenn <[email protected]>

---------

Signed-off-by: DenChenn <[email protected]>
  • Loading branch information
DenChenn authored Nov 11, 2024
1 parent 25cfe16 commit 9889b0e
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 14 deletions.
32 changes: 19 additions & 13 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package v1alpha1

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -10,21 +9,28 @@ import (
)

func TestExecutionErrorJSONMarshalling(t *testing.T) {
execError := &core.ExecutionError{
Code: "TestCode",
Message: "Test error message",
ErrorUri: "Test error uri",
execError := ExecutionError{
&core.ExecutionError{
Code: "TestCode",
Message: "Test error message",
ErrorUri: "Test error uri",
},
}

execErr := &ExecutionError{ExecutionError: execError}
data, jErr := json.Marshal(execErr)
assert.Nil(t, jErr)
expected, mockErr := mockMarshalPbToBytes(execError.ExecutionError)
assert.Nil(t, mockErr)

newExecErr := &ExecutionError{}
uErr := json.Unmarshal(data, newExecErr)
// MarshalJSON
execErrorBytes, mErr := execError.MarshalJSON()
assert.Nil(t, mErr)
assert.Equal(t, expected, execErrorBytes)

// UnmarshalJSON
execErrorObj := &ExecutionError{}
uErr := execErrorObj.UnmarshalJSON(execErrorBytes)
assert.Nil(t, uErr)

assert.Equal(t, execError.Code, newExecErr.ExecutionError.Code)
assert.Equal(t, execError.Message, newExecErr.ExecutionError.Message)
assert.Equal(t, execError.ErrorUri, newExecErr.ExecutionError.ErrorUri)
assert.Equal(t, execError.Code, execErrorObj.Code)
assert.Equal(t, execError.Message, execError.Message)
assert.Equal(t, execError.ErrorUri, execErrorObj.ErrorUri)
}
150 changes: 150 additions & 0 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package v1alpha1

import (
"bytes"
"testing"

"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/durationpb"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

func mockMarshalPbToBytes(msg proto.Message) ([]byte, error) {
var buf bytes.Buffer
jMarshaller := jsonpb.Marshaler{}
if err := jMarshaller.Marshal(&buf, msg); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func TestApproveConditionJSONMarshalling(t *testing.T) {
approveCondition := ApproveCondition{
&core.ApproveCondition{
SignalId: "TestSignalId",
},
}

expected, mockErr := mockMarshalPbToBytes(approveCondition.ApproveCondition)
assert.Nil(t, mockErr)

// MarshalJSON
approveConditionBytes, mErr := approveCondition.MarshalJSON()
assert.Nil(t, mErr)
assert.Equal(t, expected, approveConditionBytes)

// UnmarshalJSON
approveConditionObj := &ApproveCondition{}
uErr := approveConditionObj.UnmarshalJSON(approveConditionBytes)
assert.Nil(t, uErr)
assert.Equal(t, approveCondition.SignalId, approveConditionObj.SignalId)
}

func TestSignalConditionJSONMarshalling(t *testing.T) {
signalCondition := SignalCondition{
&core.SignalCondition{
SignalId: "TestSignalId",
},
}

expected, mockErr := mockMarshalPbToBytes(signalCondition.SignalCondition)
assert.Nil(t, mockErr)

// MarshalJSON
signalConditionBytes, mErr := signalCondition.MarshalJSON()
assert.Nil(t, mErr)
assert.Equal(t, expected, signalConditionBytes)

// UnmarshalJSON
signalConditionObj := &SignalCondition{}
uErr := signalConditionObj.UnmarshalJSON(signalConditionBytes)
assert.Nil(t, uErr)
assert.Equal(t, signalCondition.SignalId, signalConditionObj.SignalId)
}

func TestSleepConditionJSONMarshalling(t *testing.T) {
sleepCondition := SleepCondition{
&core.SleepCondition{
Duration: &durationpb.Duration{
Seconds: 10,
Nanos: 10,
},
},
}

expected, mockErr := mockMarshalPbToBytes(sleepCondition.SleepCondition)
assert.Nil(t, mockErr)

// MarshalJSON
sleepConditionBytes, mErr := sleepCondition.MarshalJSON()
assert.Nil(t, mErr)
assert.Equal(t, expected, sleepConditionBytes)

// UnmarshalJSON
sleepConditionObj := &SleepCondition{}
uErr := sleepConditionObj.UnmarshalJSON(sleepConditionBytes)
assert.Nil(t, uErr)
assert.Equal(t, sleepCondition.Duration, sleepConditionObj.Duration)
}

func TestGateNodeSpec_GetKind(t *testing.T) {
kind := ConditionKindApprove
gateNodeSpec := GateNodeSpec{
Kind: kind,
}

if gateNodeSpec.GetKind() != kind {
t.Errorf("Expected %s, but got %s", kind, gateNodeSpec.GetKind())
}
}

func TestGateNodeSpec_GetApprove(t *testing.T) {
approveCondition := &ApproveCondition{
&core.ApproveCondition{
SignalId: "TestSignalId",
},
}
gateNodeSpec := GateNodeSpec{
Approve: approveCondition,
}

if gateNodeSpec.GetApprove() != approveCondition.ApproveCondition {
t.Errorf("Expected approveCondition, but got a different value")
}
}

func TestGateNodeSpec_GetSignal(t *testing.T) {
signalCondition := &SignalCondition{
&core.SignalCondition{
SignalId: "TestSignalId",
},
}
gateNodeSpec := GateNodeSpec{
Signal: signalCondition,
}

if gateNodeSpec.GetSignal() != signalCondition.SignalCondition {
t.Errorf("Expected signalCondition, but got a different value")
}
}

func TestGateNodeSpec_GetSleep(t *testing.T) {
sleepCondition := &SleepCondition{
&core.SleepCondition{
Duration: &durationpb.Duration{
Seconds: 10,
Nanos: 10,
},
},
}
gateNodeSpec := GateNodeSpec{
Sleep: sleepCondition,
}

if gateNodeSpec.GetSleep() != sleepCondition.SleepCondition {
t.Errorf("Expected sleepCondition, but got a different value")
}
}
115 changes: 115 additions & 0 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package v1alpha1

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

func TestIdentifierJSONMarshalling(t *testing.T) {
identifier := Identifier{
&core.Identifier{
ResourceType: core.ResourceType_TASK,
Project: "TestProject",
Domain: "TestDomain",
Name: "TestName",
Version: "TestVersion",
},
}

expected, mockErr := mockMarshalPbToBytes(identifier.Identifier)
assert.Nil(t, mockErr)

// MarshalJSON
identifierBytes, mErr := identifier.MarshalJSON()
assert.Nil(t, mErr)
assert.Equal(t, expected, identifierBytes)

// UnmarshalJSON
identifierObj := &Identifier{}
uErr := identifierObj.UnmarshalJSON(identifierBytes)
assert.Nil(t, uErr)
assert.Equal(t, identifier.Project, identifierObj.Project)
assert.Equal(t, identifier.Domain, identifierObj.Domain)
assert.Equal(t, identifier.Name, identifierObj.Name)
assert.Equal(t, identifier.Version, identifierObj.Version)
}

func TestIdentifier_DeepCopyInto(t *testing.T) {
identifier := Identifier{
&core.Identifier{
ResourceType: core.ResourceType_TASK,
Project: "TestProject",
Domain: "TestDomain",
Name: "TestName",
Version: "TestVersion",
},
}

identifierCopy := Identifier{}
identifier.DeepCopyInto(&identifierCopy)
assert.Equal(t, identifier.Project, identifierCopy.Project)
assert.Equal(t, identifier.Domain, identifierCopy.Domain)
assert.Equal(t, identifier.Name, identifierCopy.Name)
assert.Equal(t, identifier.Version, identifierCopy.Version)
}

func TestWorkflowExecutionIdentifier_DeepCopyInto(t *testing.T) {
weIdentifier := WorkflowExecutionIdentifier{
&core.WorkflowExecutionIdentifier{
Project: "TestProject",
Domain: "TestDomain",
Name: "TestName",
Org: "TestOrg",
},
}

weIdentifierCopy := WorkflowExecutionIdentifier{}
weIdentifier.DeepCopyInto(&weIdentifierCopy)
assert.Equal(t, weIdentifier.Project, weIdentifierCopy.Project)
assert.Equal(t, weIdentifier.Domain, weIdentifierCopy.Domain)
assert.Equal(t, weIdentifier.Name, weIdentifierCopy.Name)
assert.Equal(t, weIdentifier.Org, weIdentifierCopy.Org)
}

func TestTaskExecutionIdentifier_DeepCopyInto(t *testing.T) {
teIdentifier := TaskExecutionIdentifier{
&core.TaskExecutionIdentifier{
TaskId: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Project: "TestProject",
Domain: "TestDomain",
Name: "TestName",
Version: "TestVersion",
Org: "TestOrg",
},
NodeExecutionId: &core.NodeExecutionIdentifier{
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "TestProject",
Domain: "TestDomain",
Name: "TestName",
Org: "TestOrg",
},
NodeId: "TestNodeId",
},
RetryAttempt: 1,
},
}

teIdentifierCopy := TaskExecutionIdentifier{}
teIdentifier.DeepCopyInto(&teIdentifierCopy)
assert.Equal(t, teIdentifier.TaskId.ResourceType, teIdentifierCopy.TaskId.ResourceType)
assert.Equal(t, teIdentifier.TaskId.Project, teIdentifierCopy.TaskId.Project)
assert.Equal(t, teIdentifier.TaskId.Domain, teIdentifierCopy.TaskId.Domain)
assert.Equal(t, teIdentifier.TaskId.Name, teIdentifierCopy.TaskId.Name)
assert.Equal(t, teIdentifier.TaskId.Version, teIdentifierCopy.TaskId.Version)
assert.Equal(t, teIdentifier.TaskId.Org, teIdentifierCopy.TaskId.Org)
assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Project, teIdentifierCopy.NodeExecutionId.ExecutionId.Project)
assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Domain, teIdentifierCopy.NodeExecutionId.ExecutionId.Domain)
assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Name, teIdentifierCopy.NodeExecutionId.ExecutionId.Name)
assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Org, teIdentifierCopy.NodeExecutionId.ExecutionId.Org)
assert.Equal(t, teIdentifier.NodeExecutionId.NodeId, teIdentifierCopy.NodeExecutionId.NodeId)
assert.Equal(t, teIdentifier.RetryAttempt, teIdentifierCopy.RetryAttempt)
}
2 changes: 1 addition & 1 deletion flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const FlyteWorkflowKind = "flyteworkflow"
// SchemeGroupVersion is group version used to register these objects
var SchemeGroupVersion = schema.GroupVersion{Group: flyteworkflow.GroupName, Version: "v1alpha1"}

// GetKind takes an unqualified kind and returns back a Group qualified GroupKind
// Kind takes an unqualified kind and returns back a Group qualified GroupKind
func Kind(kind string) schema.GroupKind {
return SchemeGroupVersion.WithKind(kind).GroupKind()
}
Expand Down
28 changes: 28 additions & 0 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package v1alpha1

import (
"testing"

"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/runtime"
)

func TestKind(t *testing.T) {
kind := "test kind"
got := Kind(kind)
want := SchemeGroupVersion.WithKind(kind).GroupKind()
assert.Equal(t, got, want)
}

func TestResource(t *testing.T) {
resource := "test resource"
got := Resource(resource)
want := SchemeGroupVersion.WithResource(resource).GroupResource()
assert.Equal(t, got, want)
}

func Test_addKnownTypes(t *testing.T) {
scheme := runtime.NewScheme()
err := addKnownTypes(scheme)
assert.Nil(t, err)
}
36 changes: 36 additions & 0 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package v1alpha1

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

func TestWorkflowNodeSpec_GetLaunchPlanRefID(t *testing.T) {
wfNodeSpec := &WorkflowNodeSpec{
LaunchPlanRefID: &LaunchPlanRefID{
&core.Identifier{
Project: "TestProject",
},
},
}

nilWfNodeSpec := &WorkflowNodeSpec{}

assert.Equal(t, wfNodeSpec.GetLaunchPlanRefID(), wfNodeSpec.LaunchPlanRefID)
assert.Empty(t, nilWfNodeSpec.GetLaunchPlanRefID())
}

func TestWorkflowNodeSpec_GetSubWorkflowRef(t *testing.T) {
workflowID := "TestWorkflowID"
wfNodeSpec := &WorkflowNodeSpec{
SubWorkflowReference: &workflowID,
}

nilWfNodeSpec := &WorkflowNodeSpec{}

assert.Equal(t, wfNodeSpec.GetSubWorkflowRef(), wfNodeSpec.SubWorkflowReference)
assert.Empty(t, nilWfNodeSpec.GetSubWorkflowRef())
}

0 comments on commit 9889b0e

Please sign in to comment.