Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: APIs for Backend Changes for Default Values #965

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions conversion/conversion_from_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type DataFromSourceImpl struct{}
func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) {
conv := internal.MakeConv()
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
conv.Source = sourceProfile.Driver
//handle fetching schema differently for sharded migrations, we only connect to the primary shard to
//fetch the schema. We reuse the SourceProfileConnection object for this purpose.
var infoSchema common.InfoSchema
Expand Down Expand Up @@ -159,6 +162,9 @@ func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile p
return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source")
}
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
conv.Source = sourceProfile.Driver
dialect, err := targetProfile.FetchTargetDialect(ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch dialect: %v", err)
Expand Down
107 changes: 104 additions & 3 deletions expressions_api/expression_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"sync"

spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/task"
Expand All @@ -18,22 +19,49 @@ const THREAD_POOL = 500
type ExpressionVerificationAccessor interface {
//Batch API which parallelizes expression verification calls
VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput
RefreshSpannerClient(ctx context.Context, project string, instance string) error
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
}

type ExpressionVerificationAccessorImpl struct {
SpannerAccessor *spanneraccessor.SpannerAccessorImpl
}

func NewExpressionVerificationAccessorImpl(ctx context.Context, project string, instance string) (*ExpressionVerificationAccessorImpl, error) {
spannerAccessor, err := spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db"))
if err != nil {
return nil, err
var spannerAccessor *spanneraccessor.SpannerAccessorImpl
var err error
if project != "" && instance != "" {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db"))
if err != nil {
return nil, err
}
} else {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImpl(ctx)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont have full context but Why are we multiplexing this here? Seems like this project id based if else logic should reside inside NewSpannerAccessorClientImpl() that creates accessor without spanner client if fields are empty.

if err != nil {
return nil, err
}
}
return &ExpressionVerificationAccessorImpl{
SpannerAccessor: spannerAccessor,
}, nil
}

type DDLVerifier interface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: in general its good practice to document behaviour for methods in interface.

VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error)
GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
RefreshSpannerClient(ctx context.Context, project string, instance string) error
}
type DDLVerifierImpl struct {
Expressions ExpressionVerificationAccessor
}

func NewDDLVerifierImpl(ctx context.Context, project string, instance string) (*DDLVerifierImpl, error) {
expVerifier, err := NewExpressionVerificationAccessorImpl(ctx, project, instance)
return &DDLVerifierImpl{
Expressions: expVerifier,
}, err
}

func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
err := ev.validateRequest(verifyExpressionsInput)
if err != nil {
Expand Down Expand Up @@ -79,6 +107,15 @@ func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Cont
return verifyExpressionsOutput
}

func (ev *ExpressionVerificationAccessorImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
spannerClient, err := spannerclient.NewSpannerClientImpl(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db"))
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"smt-staging-db" should be a constant string across the codebase.

if err != nil {
return err
}
ev.SpannerAccessor.SpannerClient = spannerClient
return nil
}

func (ev *ExpressionVerificationAccessorImpl) verifyExpressionInternal(expressionDetail internal.ExpressionDetail, mutex *sync.Mutex) task.TaskResult[internal.ExpressionVerificationOutput] {
var sqlStatement string
switch expressionDetail.Type {
Expand Down Expand Up @@ -129,3 +166,67 @@ func (ev *ExpressionVerificationAccessorImpl) removeExpressions(inputConv *inter
}
return convCopy, nil
}

func (ddlv *DDLVerifierImpl) VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) {
ctx := context.Background()
verifyExpressionsInput := internal.VerifyExpressionsInput{
Conv: conv,
Source: conv.Source,
ExpressionDetailList: expressionDetails,
}
verificationResults := ddlv.Expressions.VerifyExpressions(ctx, verifyExpressionsInput)

return verificationResults, verificationResults.Err
}

func (ddlv *DDLVerifierImpl) GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
expressionDetails := []internal.ExpressionDetail{}
// Collect default values for verification
for _, tableId := range tableIds {
srcTable := conv.SrcSchema[tableId]
for _, srcColId := range srcTable.ColIds {
srcCol := srcTable.ColDefs[srcColId]
if srcCol.DefaultValue.IsPresent {
defaultValueExp := internal.ExpressionDetail{
ReferenceElement: internal.ReferenceElement{
Name: conv.SpSchema[tableId].ColDefs[srcColId].T.Name,
},
ExpressionId: srcCol.DefaultValue.Value.ExpressionId,
Expression: srcCol.DefaultValue.Value.Query,
Type: "DEFAULT",
Metadata: map[string]string{"TableId": tableId, "ColId": srcColId},
}
expressionDetails = append(expressionDetails, defaultValueExp)
}
}
}
return expressionDetails
}

func (ddlv *DDLVerifierImpl) GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
expressionDetails := []internal.ExpressionDetail{}
// Collect default values for verification
for _, tableId := range tableIds {
spTable := conv.SpSchema[tableId]
for _, spColId := range spTable.ColIds {
spCol := spTable.ColDefs[spColId]
if spCol.DefaultValue.IsPresent {
darshan-sj marked this conversation as resolved.
Show resolved Hide resolved
defaultValueExp := internal.ExpressionDetail{
ReferenceElement: internal.ReferenceElement{
Name: conv.SpSchema[tableId].ColDefs[spColId].T.Name,
},
ExpressionId: spCol.DefaultValue.Value.ExpressionId,
Expression: spCol.DefaultValue.Value.Query,
Type: "DEFAULT",
Metadata: map[string]string{"TableId": tableId, "ColId": spColId},
}
expressionDetails = append(expressionDetails, defaultValueExp)
}
}
}
return expressionDetails
}

func (ddlv *DDLVerifierImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
return ddlv.Expressions.RefreshSpannerClient(ctx, project, instance)
}
187 changes: 185 additions & 2 deletions expressions_api/expression_verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
"github.com/GoogleCloudPlatform/spanner-migration-tool/schema"
"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
"github.com/googleapis/gax-go/v2"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
Expand All @@ -32,8 +34,8 @@ func TestVerifyExpressions(t *testing.T) {
conv := internal.MakeConv()
ReadSessionFile(conv, "../../test_data/session_expression_verify.json")
input := internal.VerifyExpressionsInput{
Conv: conv,
Source: "mysql",
Conv: conv,
Source: "mysql",
ExpressionDetailList: []internal.ExpressionDetail{
{
Expression: "id > 10",
Expand Down Expand Up @@ -297,3 +299,184 @@ func ReadSessionFile(conv *internal.Conv, sessionJSON string) error {
}
return nil
}

func TestVerifySpannerDDL(t *testing.T) {
conv := *internal.MakeConv()
testCases := []struct {
name string
conv internal.Conv
expressionDetails []internal.ExpressionDetail
verifyExpressionMock expressions_api.MockExpressionVerificationAccessor
errorExpected bool
}{
{
name: "no error flow",
conv: conv,
expressionDetails: []internal.ExpressionDetail{},
verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{
VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
return internal.VerifyExpressionsOutput{
ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{},
Err: nil,
}
},
},
errorExpected: false,
},
{
name: "error flow",
conv: conv,
expressionDetails: []internal.ExpressionDetail{},
verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{
VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
return internal.VerifyExpressionsOutput{
ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{},
Err: fmt.Errorf("error"),
}
},
},
errorExpected: true,
},
}

for _, tc := range testCases {
ddlV := expressions_api.DDLVerifierImpl{
Expressions: &tc.verifyExpressionMock,
}
_, err := ddlV.VerifySpannerDDL(&tc.conv, tc.expressionDetails)
assert.Equal(t, tc.errorExpected, err != nil)
}
}

func TestGetSourceExpressionDetails(t *testing.T) {
conv := internal.MakeConv()
conv.SrcSchema = map[string]schema.Table{
"table1": {
ColIds: []string{"col1", "col2"},
ColDefs: map[string]schema.Column{
"col1": {
DefaultValue: ddl.DefaultValue{
IsPresent: true,
Value: ddl.Expression{
ExpressionId: "expr1",
Query: "SELECT 1",
},
},
},
"col2": {
DefaultValue: ddl.DefaultValue{},
},
},
},
}
conv.SpSchema = ddl.Schema{
"table1": {
ColDefs: map[string]ddl.ColumnDef{
"col1": {
T: ddl.Type{
Name: "INT64",
},
},
},
},
}

testCases := []struct {
name string
conv *internal.Conv
tableIds []string
expectedDetails []internal.ExpressionDetail
}{
{
name: "single table with default value",
conv: conv,
tableIds: []string{"table1"},
expectedDetails: []internal.ExpressionDetail{
{
ReferenceElement: internal.ReferenceElement{
Name: "INT64",
},
ExpressionId: "expr1",
Expression: "SELECT 1",
Type: "DEFAULT",
Metadata: map[string]string{"TableId": "table1", "ColId": "col1"},
},
},
},
{
name: "no tables",
conv: conv,
tableIds: []string{},
expectedDetails: []internal.ExpressionDetail{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ddlv := &expressions_api.DDLVerifierImpl{}
actualDetails := ddlv.GetSourceExpressionDetails(tc.conv, tc.tableIds)
assert.Equal(t, tc.expectedDetails, actualDetails)
})
}
}

func TestGetSpannerExpressionDetails(t *testing.T) {
conv := internal.MakeConv()
conv.SpSchema = ddl.Schema{
"table1": {
ColIds: []string{"col1", "col2"},
ColDefs: map[string]ddl.ColumnDef{
"col1": {
DefaultValue: ddl.DefaultValue{
IsPresent: true,
Value: ddl.Expression{
ExpressionId: "expr1",
Query: "SELECT 1",
},
},
},
"col2": {
DefaultValue: ddl.DefaultValue{},
},
},
},
}

testCases := []struct {
name string
conv *internal.Conv
tableIds []string
expectedDetails []internal.ExpressionDetail
}{
{
name: "single table with default value",
conv: conv,
tableIds: []string{"table1"},
expectedDetails: []internal.ExpressionDetail{
{
ReferenceElement: internal.ReferenceElement{
Name: conv.SpSchema["table1"].ColDefs["col1"].T.Name,
},
ExpressionId: "expr1",
Expression: "SELECT 1",
Type: "DEFAULT",
Metadata: map[string]string{"TableId": "table1", "ColId": "col1"},
},
},
},
{
name: "no tables",
conv: conv,
tableIds: []string{},
expectedDetails: []internal.ExpressionDetail{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ddlv := &expressions_api.DDLVerifierImpl{}
actualDetails := ddlv.GetSpannerExpressionDetails(tc.conv, tc.tableIds)
assert.Equal(t, tc.expectedDetails, actualDetails)
})
}
}
Loading
Loading