diff --git a/flyteadmin/pkg/repositories/gormimpl/common.go b/flyteadmin/pkg/repositories/gormimpl/common.go index 330555be8f..b103ef0e43 100644 --- a/flyteadmin/pkg/repositories/gormimpl/common.go +++ b/flyteadmin/pkg/repositories/gormimpl/common.go @@ -115,3 +115,7 @@ func applyScopedFilters(tx *gorm.DB, inlineFilters []common.InlineFilter, mapFil } return tx, nil } + +func getIDFilter(id uint) (query string, args interface{}) { + return fmt.Sprintf("%s = ?", ID), id +} diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go index 69345fc06d..d34f60ff64 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo.go @@ -68,7 +68,7 @@ func (r *ExecutionRepo) Get(ctx context.Context, input interfaces.Identifier) (m func (r *ExecutionRepo) Update(ctx context.Context, execution models.Execution) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).Model(&execution).Updates(execution) + tx := r.db.WithContext(ctx).Model(&models.Execution{}).Where(getIDFilter(execution.ID)).Updates(execution) timer.Stop() if err := tx.Error; err != nil { return r.errorTransformer.ToFlyteAdminError(err) diff --git a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go index 1b4068d4f1..e1e3117c6e 100644 --- a/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/execution_repo_test.go @@ -59,10 +59,7 @@ func TestUpdateExecution(t *testing.T) { updated := false // Only match on queries that append expected filters - GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,` + - `"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,` + - `"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE "` + - `execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16`).WithCallback( + GlobalMock.NewMock().WithQuery(`UPDATE "executions" SET "updated_at"=$1,"execution_project"=$2,"execution_domain"=$3,"execution_name"=$4,"launch_plan_id"=$5,"workflow_id"=$6,"phase"=$7,"closure"=$8,"spec"=$9,"started_at"=$10,"execution_created_at"=$11,"execution_updated_at"=$12,"duration"=$13 WHERE id = $14`).WithCallback( func(s string, values []driver.NamedValue) { updated = true }, diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go index 70833d4d77..b1772862dc 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go @@ -97,7 +97,7 @@ func (r *NodeExecutionRepo) GetWithChildren(ctx context.Context, input interface func (r *NodeExecutionRepo) Update(ctx context.Context, nodeExecution *models.NodeExecution) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).Model(&nodeExecution).Updates(nodeExecution) + tx := r.db.WithContext(ctx).Model(&models.NodeExecution{}).Where(getIDFilter(nodeExecution.ID)).Updates(nodeExecution) timer.Stop() if err := tx.Error; err != nil { return r.errorTransformer.ToFlyteAdminError(err) diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go index d35f8ac4f4..fe294b0a41 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -64,7 +64,7 @@ func TestUpdateNodeExecution(t *testing.T) { GlobalMock := mocket.Catcher.Reset() // Only match on queries that append the name filter nodeExecutionQuery := GlobalMock.NewMock() - nodeExecutionQuery.WithQuery(`UPDATE "node_executions" SET "id"=$1,"updated_at"=$2,"execution_project"=$3,"execution_domain"=$4,"execution_name"=$5,"node_id"=$6,"phase"=$7,"input_uri"=$8,"closure"=$9,"started_at"=$10,"node_execution_created_at"=$11,"node_execution_updated_at"=$12,"duration"=$13 WHERE "execution_project" = $14 AND "execution_domain" = $15 AND "execution_name" = $16 AND "node_id" = $17`) + nodeExecutionQuery.WithQuery(`UPDATE "node_executions" SET "id"=$1,"updated_at"=$2,"execution_project"=$3,"execution_domain"=$4,"execution_name"=$5,"node_id"=$6,"phase"=$7,"input_uri"=$8,"closure"=$9,"started_at"=$10,"node_execution_created_at"=$11,"node_execution_updated_at"=$12,"duration"=$13 WHERE id = $14`) err := nodeExecutionRepo.Update(context.Background(), &models.NodeExecution{ BaseModel: models.BaseModel{ID: 1}, diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go index c42d36b1bc..d4d30bef85 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go @@ -81,7 +81,8 @@ func (r *TaskExecutionRepo) Get(ctx context.Context, input interfaces.GetTaskExe func (r *TaskExecutionRepo) Update(ctx context.Context, execution models.TaskExecution) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).WithContext(ctx).Save(&execution) + tx := r.db.WithContext(ctx).Model(&models.TaskExecution{}).Where(getIDFilter(execution.ID)). + Updates(&execution) timer.Stop() if err := tx.Error; err != nil { diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go index 8ccee763c2..60a2ca2077 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go @@ -85,7 +85,7 @@ func TestUpdateTaskExecution(t *testing.T) { GlobalMock.Logging = true taskExecutionQuery := GlobalMock.NewMock() - taskExecutionQuery.WithQuery(`UPDATE "task_executions" SET "id"=$1,"created_at"=$2,"updated_at"=$3,"deleted_at"=$4,"phase"=$5,"phase_version"=$6,"input_uri"=$7,"closure"=$8,"started_at"=$9,"task_execution_created_at"=$10,"task_execution_updated_at"=$11,"duration"=$12 WHERE "project" = $13 AND "domain" = $14 AND "name" = $15 AND "version" = $16 AND "execution_project" = $17 AND "execution_domain" = $18 AND "execution_name" = $19 AND "node_id" = $20 AND "retry_attempt" = $21`) + taskExecutionQuery.WithQuery(`UPDATE "task_executions" SET "updated_at"=$1,"project"=$2,"domain"=$3,"name"=$4,"version"=$5,"execution_project"=$6,"execution_domain"=$7,"execution_name"=$8,"node_id"=$9,"retry_attempt"=$10,"phase"=$11,"input_uri"=$12,"closure"=$13,"started_at"=$14,"task_execution_created_at"=$15,"task_execution_updated_at"=$16,"duration"=$17 WHERE id = $18`) err := taskExecutionRepo.Update(context.Background(), testTaskExecution) assert.NoError(t, err) assert.True(t, taskExecutionQuery.Triggered)