Skip to content

Commit

Permalink
Feat: Refactor model hook and add beforeUpdate hook registration for …
Browse files Browse the repository at this point in the history
…field hook. #55

feat:
- hook: refactor the model hook to add support for the `beforeDelete`, `afterDelete`, `beforeUpdate`, `afterUpdate`, and `beforeFind` hooks in MongoDB operations.
- plugin: add the registration of the `beforeUpdate` hook when the field hook is enabled.
  • Loading branch information
chenmingyong0423 authored Sep 10, 2024
1 parent 206e1d3 commit a9fc72d
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 4 deletions.
6 changes: 4 additions & 2 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type PluginConfig struct {

func InitPlugin(config *PluginConfig) {
if config.EnableDefaultFieldHook {
opTypes := []operation.OpType{operation.OpTypeBeforeInsert, operation.OpTypeBeforeUpsert}
opTypes := []operation.OpType{operation.OpTypeBeforeInsert, operation.OpTypeBeforeUpdate, operation.OpTypeBeforeUpsert}
for _, opType := range opTypes {
typ := opType
RegisterPlugin("mongox:default_field", func(ctx context.Context, opCtx *operation.OpContext, opts ...any) error {
Expand All @@ -57,8 +57,10 @@ func InitPlugin(config *PluginConfig) {
if config.EnableModelHook {
opTypes := []operation.OpType{
operation.OpTypeBeforeInsert, operation.OpTypeAfterInsert,
operation.OpTypeBeforeDelete, operation.OpTypeAfterDelete,
operation.OpTypeBeforeUpdate, operation.OpTypeAfterUpdate,
operation.OpTypeBeforeUpsert, operation.OpTypeAfterUpsert,
operation.OpTypeAfterFind,
operation.OpTypeBeforeFind, operation.OpTypeAfterFind,
}
for _, opType := range opTypes {
typ := opType
Expand Down
159 changes: 159 additions & 0 deletions callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,46 @@ func TestPluginInit_EnableEnableDefaultFieldHook(t *testing.T) {
require.NotZero(t, model.ID)
require.NotZero(t, model.CreatedAt)
RemovePlugin("mongox:default_field", operation.OpTypeBeforeInsert)
RemovePlugin("mongox:default_field", operation.OpTypeBeforeUpdate)
RemovePlugin("mongox:default_field", operation.OpTypeBeforeUpsert)
})
t.Run("beforeUpdate", func(t *testing.T) {
var (
model = &Model{}
m = bson.M{}
)
err := callback.GetCallback().Execute(
context.Background(),
operation.NewOpContext(nil, operation.WithDoc(model), operation.WithUpdates(m)),
operation.OpTypeBeforeUpdate,
)
require.Nil(t, err)
require.Zero(t, model.ID)
require.Zero(t, model.CreatedAt)
require.Zero(t, model.UpdatedAt)

cfg := &PluginConfig{
EnableDefaultFieldHook: true,
}
InitPlugin(cfg)

err = callback.GetCallback().Execute(
context.Background(),
operation.NewOpContext(nil, operation.WithDoc(model), operation.WithUpdates(m)),
operation.OpTypeBeforeUpdate,
)
require.Nil(t, err)
require.Zero(t, model.ID)
require.Zero(t, model.CreatedAt)
require.NotZero(t, model.UpdatedAt)
require.Equal(t, bson.M{
"$set": bson.M{
"updated_at": model.UpdatedAt,
},
}, m)

RemovePlugin("mongox:default_field", operation.OpTypeBeforeInsert)
RemovePlugin("mongox:default_field", operation.OpTypeBeforeUpdate)
RemovePlugin("mongox:default_field", operation.OpTypeBeforeUpsert)
})
t.Run("beforeUpsert", func(t *testing.T) {
Expand Down Expand Up @@ -299,6 +339,7 @@ func TestPluginInit_EnableEnableDefaultFieldHook(t *testing.T) {
}, m)

RemovePlugin("mongox:default_field", operation.OpTypeBeforeInsert)
RemovePlugin("mongox:default_field", operation.OpTypeBeforeUpdate)
RemovePlugin("mongox:default_field", operation.OpTypeBeforeUpsert)
})
}
Expand All @@ -315,6 +356,26 @@ func (t *testModelHookStruct) AfterInsert(_ context.Context) error {
return nil
}

func (t *testModelHookStruct) BeforeDelete(_ context.Context) error {
*t++
return nil
}

func (t *testModelHookStruct) AfterDelete(_ context.Context) error {
*t++
return nil
}

func (t *testModelHookStruct) BeforeUpdate(_ context.Context) error {
*t++
return nil
}

func (t *testModelHookStruct) AfterUpdate(_ context.Context) error {
*t++
return nil
}

func (t *testModelHookStruct) BeforeUpsert(_ context.Context) error {
*t++
return nil
Expand All @@ -325,6 +386,11 @@ func (t *testModelHookStruct) AfterUpsert(_ context.Context) error {
return nil
}

func (t *testModelHookStruct) BeforeFind(_ context.Context) error {
*t++
return nil
}

func (t *testModelHookStruct) AfterFind(_ context.Context) error {
*t++
return nil
Expand Down Expand Up @@ -392,6 +458,82 @@ func TestPluginInit_EnableModelHook(t *testing.T) {
wantErr: nil,
want: 2,
},
{
name: "beforeDelete",
ctx: context.Background(),
ocOption: func(tm *testModelHookStruct) []operation.OpContextOption {
return []operation.OpContextOption{
operation.WithModelHook(tm),
}
},
opType: operation.OpTypeBeforeDelete,
wantErr: nil,
want: 1,
},
{
name: "afterDelete",
ctx: context.Background(),
ocOption: func(tm *testModelHookStruct) []operation.OpContextOption {
return []operation.OpContextOption{
operation.WithModelHook(tm),
}
},
opType: operation.OpTypeAfterDelete,
wantErr: nil,
want: 1,
},
{
name: "beforeUpdate",
ctx: context.Background(),
ocOption: func(tm *testModelHookStruct) []operation.OpContextOption {
return []operation.OpContextOption{
operation.WithUpdates(tm),
}
},
opType: operation.OpTypeBeforeUpdate,
wantErr: nil,
want: 1,
},
{
name: "beforeUpdate with model hook",
ctx: context.Background(),
ocOption: func(tm *testModelHookStruct) []operation.OpContextOption {
*tm = 1
return []operation.OpContextOption{
operation.WithUpdates(new(testModelHookStruct)),
operation.WithModelHook(tm),
}
},
opType: operation.OpTypeBeforeUpdate,
wantErr: nil,
want: 2,
},
{
name: "afterUpdate",
ctx: context.Background(),
ocOption: func(tm *testModelHookStruct) []operation.OpContextOption {
return []operation.OpContextOption{
operation.WithUpdates(tm),
}
},
opType: operation.OpTypeAfterUpdate,
wantErr: nil,
want: 1,
},
{
name: "afterUpdate with model hook",
ctx: context.Background(),
ocOption: func(tm *testModelHookStruct) []operation.OpContextOption {
*tm = 1
return []operation.OpContextOption{
operation.WithUpdates(new(testModelHookStruct)),
operation.WithModelHook(tm),
}
},
opType: operation.OpTypeAfterUpdate,
wantErr: nil,
want: 2,
},
{
name: "beforeUpsert",
ctx: context.Background(),
Expand Down Expand Up @@ -444,6 +586,18 @@ func TestPluginInit_EnableModelHook(t *testing.T) {
wantErr: nil,
want: 2,
},
{
name: "beforeFind",
ctx: context.Background(),
ocOption: func(tm *testModelHookStruct) []operation.OpContextOption {
return []operation.OpContextOption{
operation.WithModelHook(tm),
}
},
opType: operation.OpTypeBeforeFind,
wantErr: nil,
want: 1,
},
{
name: "afterFind",
ctx: context.Background(),
Expand Down Expand Up @@ -500,8 +654,13 @@ func TestPluginInit_EnableModelHook(t *testing.T) {
func remoteModelPlugin() {
RemovePlugin("mongox:model", operation.OpTypeBeforeInsert)
RemovePlugin("mongox:model", operation.OpTypeAfterInsert)
RemovePlugin("mongox:model", operation.OpTypeBeforeDelete)
RemovePlugin("mongox:model", operation.OpTypeAfterDelete)
RemovePlugin("mongox:model", operation.OpTypeBeforeUpdate)
RemovePlugin("mongox:model", operation.OpTypeAfterUpdate)
RemovePlugin("mongox:model", operation.OpTypeBeforeUpsert)
RemovePlugin("mongox:model", operation.OpTypeAfterUpsert)
RemovePlugin("mongox:model", operation.OpTypeBeforeFind)
RemovePlugin("mongox:model", operation.OpTypeAfterFind)
}

Expand Down
25 changes: 23 additions & 2 deletions hook/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ func getPayload(opCtx *operation.OpContext, opType operation.OpType) any {
if opCtx.ModelHook != nil {
return opCtx.ModelHook
}

switch opType {
case operation.OpTypeBeforeInsert, operation.OpTypeAfterInsert, operation.OpTypeAfterFind:
return opCtx.Doc
case operation.OpTypeBeforeUpsert, operation.OpTypeAfterUpsert:
case operation.OpTypeBeforeUpdate, operation.OpTypeAfterUpdate, operation.OpTypeBeforeUpsert, operation.OpTypeAfterUpsert:
return opCtx.Updates
default:
return nil
return opCtx.ModelHook
}
}

Expand Down Expand Up @@ -81,6 +82,22 @@ func execute(ctx context.Context, doc any, opType operation.OpType, _ ...any) er
if m, ok := doc.(AfterInsert); ok {
return m.AfterInsert(ctx)
}
case operation.OpTypeBeforeDelete:
if m, ok := doc.(BeforeDelete); ok {
return m.BeforeDelete(ctx)
}
case operation.OpTypeAfterDelete:
if m, ok := doc.(AfterDelete); ok {
return m.AfterDelete(ctx)
}
case operation.OpTypeBeforeUpdate:
if m, ok := doc.(BeforeUpdate); ok {
return m.BeforeUpdate(ctx)
}
case operation.OpTypeAfterUpdate:
if m, ok := doc.(AfterUpdate); ok {
return m.AfterUpdate(ctx)
}
case operation.OpTypeBeforeUpsert:
if m, ok := doc.(BeforeUpsert); ok {
return m.BeforeUpsert(ctx)
Expand All @@ -89,6 +106,10 @@ func execute(ctx context.Context, doc any, opType operation.OpType, _ ...any) er
if m, ok := doc.(AfterUpsert); ok {
return m.AfterUpsert(ctx)
}
case operation.OpTypeBeforeFind:
if m, ok := doc.(BeforeFind); ok {
return m.BeforeFind(ctx)
}
case operation.OpTypeAfterFind:
if m, ok := doc.(AfterFind); ok {
return m.AfterFind(ctx)
Expand Down
Loading

0 comments on commit a9fc72d

Please sign in to comment.