From f5d991a65c9217bbd4a68a194cddd3a18e6dbfff Mon Sep 17 00:00:00 2001 From: Taufan Adhitya Date: Wed, 2 Oct 2024 14:19:03 +0700 Subject: [PATCH 1/5] feat: refactor relation utility --- pkg/db/relation.go | 214 ++++++++++++++++++++++++++++++---------- pkg/db/relation_test.go | 169 ++++--------------------------- pkg/db/utils.go | 42 -------- 3 files changed, 183 insertions(+), 242 deletions(-) diff --git a/pkg/db/relation.go b/pkg/db/relation.go index a9367204..f58e0b32 100644 --- a/pkg/db/relation.go +++ b/pkg/db/relation.go @@ -2,89 +2,203 @@ package db import ( "fmt" + "reflect" "strings" "github.com/sev-2/raiden" ) -func (q *Query) With(r string, columns map[string][]string) *Query { +func (q *Query) Preload(table string, args ...string) *Query { - relations := strings.Split(r, ".") + relatedFieldPrefix := "" + relationMap := make(map[string]map[string]string) + + field := "" + operator := "" + value := "" + + // Override with supplied arguments if available + if len(args) > 0 && args[0] != "" { + field = args[0] + } + + if len(args) > 1 && args[1] != "" { + operator = args[1] + } + if len(args) > 2 && args[2] != "" { + value = args[2] + } + + relations := strings.Split(table, ".") + + fmt.Printf("Preloading table: %s, field: %s, operator: %s, value: %s\n", table, field, operator, value) if len(relations) > 3 { raiden.Fatal("unsupported nested relations more than 3 levels") } - for _, m := range relations { - if findModel(m) == nil { - raiden.Fatal(fmt.Sprintf("invalid model name: %s", m)) + for i, relation := range relations { + var currentModelStruct reflect.Type + var relatedModel interface{} + var err error + if i == 0 { + currentModelStruct = reflect.TypeOf(q.model) + relatedModel, err = instantiateFieldByPath(q.model, relation) + } else { + currentModelStruct = reflect.TypeOf(relatedModel) + relatedModel, err = instantiateFieldByPath(relatedModel, relation) } - } - - var selects []string - for _, r := range reverseSortString(relations) { - model := findModel(r) - table := GetTable(model) - - for k := range columns { - if strings.Contains(k, "!") { - split := strings.Split(k, "!") - m := findModelByTable(split[0]) - c := findModel(m) - if !isForeignKeyExist(c, split[1]) { - err := fmt.Sprintf("invalid foreign key: \"%s\" key is not exist.", split[1]) - raiden.Fatal(err) - } else { - table = fmt.Sprintf("%s!%s", table, split[1]) - } - } + if err != nil { + raiden.Fatal("could not find related model.") } - // Columns validations - for _, c := range columns[table] { - var column = c + fmt.Printf("Related model: %v\n", relatedModel) + relatedModelStruct := reflect.TypeOf(relatedModel) + if relatedModelStruct.Kind() == reflect.Ptr { + relatedModelStruct = relatedModelStruct.Elem() + } - if column == "*" { - continue + var relatedAlias string + var relatedTableName string + var relatedForeignKey string + for i := 0; i < relatedModelStruct.NumField(); i++ { + field := relatedModelStruct.Field(i) + if field.Name == "Metadata" { + relatedTableName = field.Tag.Get("tableName") } + } - if strings.Contains(c, ":") { - split := strings.Split(c, ":") - alias := split[0] - column = split[1] - if !isValidColumnName(alias) { - err := fmt.Sprintf("invalid alias column name: \"%s\" name is invalid.", alias) - raiden.Fatal(err) - } - } + if currentModelStruct.Kind() == reflect.Ptr { + currentModelStruct = currentModelStruct.Elem() + } - if !isColumnExist(model, column) { - err := fmt.Sprintf("invalid column: \"%s\" is not available on \"%s\" table.", column, table) - raiden.Fatal(err) - } + for i := 0; i < currentModelStruct.NumField(); i++ { + field := currentModelStruct.Field(i) + if field.Name == relation { + jsonField := field.Tag.Get("json") + join := field.Tag.Get("join") - if !isValidColumnName(column) { - err := fmt.Sprintf("invalid column: \"%s\" name is invalid.", column) - raiden.Fatal(err) + relatedAlias = strings.Split(jsonField, ",")[0] + relatedForeignKey, err = getTagValue(join, "foreignKey") + + if err != nil { + raiden.Fatal("could not find foreign key in join tag.") + } } } - cols := strings.Join(columns[table], ",") + relationData := make(map[string]string) + relationData["alias"] = relatedAlias + relationData["table"] = relatedTableName + relationData["fk"] = relatedForeignKey + relationMap[relation] = relationData + } + + var selects []string - if len(cols) == 0 { - cols = "*" + // After we have the relation map, we can construct the select query + // If the table is `Users.Team.Organization`, + // the select query will be `users(teams(organizations(*)))` + for _, r := range reverseSortString(relations) { + d := relationMap[r] + alias := d["alias"] + table := d["table"] + fk := d["fk"] + + var related string + if alias == table { + related = fmt.Sprintf("%s!%s", table, fk) + } else { + related = fmt.Sprintf("%s:%s!%s", alias, table, fk) } if len(selects) > 0 { lastQuery := selects[len(selects)-1] - selects[len(selects)-1] = fmt.Sprintf("%s(%s,%s)", table, cols, lastQuery) + selects[len(selects)-1] = fmt.Sprintf("%s(%s,%s)", related, "*", lastQuery) } else { - selects = append(selects, fmt.Sprintf("%s(%s)", table, cols)) + selects = append(selects, fmt.Sprintf("%s(%s)", related, "*")) + } + + if (relatedFieldPrefix == "") { + relatedFieldPrefix = table + } else { + relatedFieldPrefix = fmt.Sprintf("%s.%s", relatedFieldPrefix, table) } } + fmt.Println("Relations: ", relationMap) + fmt.Println("Selects: ", selects) + fmt.Println("Prefix: ", relatedFieldPrefix) q.Relations = append(q.Relations, selects...) + if field != "" && operator != "" && value != "" { + if q.WhereAndList == nil { + q.WhereAndList = &[]string{} + } + + *q.WhereAndList = append( + *q.WhereAndList, + fmt.Sprintf("%s=%s.%s", fmt.Sprintf("%s.%s", relatedFieldPrefix, field), operator, getStringValue(value)), + ) + } + return q } + +func instantiateFieldByPath(model interface{}, fieldPath string) (interface{}, error) { + fields := strings.Split(fieldPath, ".") + val := reflect.ValueOf(model) + + // If it's a pointer, dereference it, but keep track of the original value to modify it + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected a struct, got %s", val.Kind()) + } + + // Traverse the struct fields based on the field path + for _, fieldName := range fields { + fieldVal := val.FieldByName(fieldName) + + if !fieldVal.IsValid() { + return nil, fmt.Errorf("field %s not found", fieldName) + } + + if fieldVal.Kind() == reflect.Ptr { + if fieldVal.IsNil() { + fieldType := fieldVal.Type() + + fieldVal = reflect.New(fieldType.Elem()) + } + fieldVal = fieldVal.Elem() + } + + if fieldVal.Kind() != reflect.Struct { + return nil, fmt.Errorf("field %s is not a struct", fieldName) + } + + val = fieldVal + } + + newInstance := reflect.New(val.Type()).Elem().Interface() + return newInstance, nil +} + +func getTagValue(tag, key string) (string, error) { + // Split the tag by semicolon to get individual key-value pairs + pairs := strings.Split(tag, ";") + + // Iterate through the pairs to find the key + for _, pair := range pairs { + // Split the pair by colon to get key and value + kv := strings.Split(pair, ":") + if len(kv) == 2 && kv[0] == key { + return kv[1], nil // Return the value if key matches + } + } + + return "", fmt.Errorf("key %s not found in tag", key) +} diff --git a/pkg/db/relation_test.go b/pkg/db/relation_test.go index 3fc67c9f..9e02fcde 100644 --- a/pkg/db/relation_test.go +++ b/pkg/db/relation_test.go @@ -3,183 +3,52 @@ package db import ( "testing" - "github.com/sev-2/raiden/pkg/resource" "github.com/stretchr/testify/assert" ) func TestWith(t *testing.T) { - resource.RegisterModels( - ArticleMockModel{}, - UsersMockModel{}, - TeamsMockModel{}, - OrganizationsMockModel{}, - ) + articleMockModel := ArticleMockModel{} + orderMockModel := OrdersMockModel{} t.Run("match url query for single relation", func(t *testing.T) { - t.Run("without selected columns", func(t *testing.T) { + t.Run("without where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). Model(articleMockModel). - With("UsersMockModel", nil). + Preload("User"). GetUrl() - assert.Equal(t, "/rest/v1/articles?select=*,users(*)", url) + assert.Equal(t, "/rest/v1/articles?select=*,user:users!user_id(*)", url) }) - t.Run("with selected columns", func(t *testing.T) { + t.Run("with where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). Model(articleMockModel). - With( - "UsersMockModel", - map[string][]string{ - "users": {"id", "username"}, - }, - ). + Preload("User", "status", "eq", "approved"). GetUrl() - assert.Equal(t, "/rest/v1/articles?select=*,users(id,username)", url) - }) - - t.Run("with selected columns and aliases", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel", - map[string][]string{ - "users": {"id", "userid:username"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,userid:username)", url) - }) - }) - - t.Run("match url query for two-nested relation", func(t *testing.T) { - t.Run("without selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With("UsersMockModel.TeamsMockModel", nil). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(*,teams(*))", url) - }) - - t.Run("with selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel", - map[string][]string{ - "users": {"id", "username"}, - "teams": {"id", "name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,username,teams(id,name))", url) - }) - - t.Run("with selected columns and aliases", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel", - map[string][]string{ - "users": {"id", "userid:username"}, - "teams": {"id", "team_name:name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,userid:username,teams(id,team_name:name))", url) + assert.Equal(t, "/rest/v1/articles?select=*,user:users!user_id(*)&users.status=eq.approved", url) }) }) - t.Run("match url query for three-nested relation", func(t *testing.T) { - t.Run("without selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With("UsersMockModel.TeamsMockModel.OrganizationsMockModel", nil). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(*,teams(*,organizations(*)))", url) - }) - - t.Run("with selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel.OrganizationsMockModel", - map[string][]string{ - "users": {"id", "username"}, - "teams": {"id", "name"}, - "organizations": {"id", "name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,username,teams(id,name,organizations(id,name)))", url) - }) - - t.Run("with selected columns and aliases", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel.OrganizationsMockModel", - map[string][]string{ - "users": {"id", "userid:username"}, - "teams": {"id", "team_name:name"}, - "organizations": {"id", "org_name:name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,userid:username,teams(id,team_name:name,organizations(id,org_name:name)))", url) - }) - }) - - t.Run("match url query with foreign key", func(t *testing.T) { - - t.Run("without selected column", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(OrdersMockModel{}). - With( - "UsersMockModel", - map[string][]string{ - "users!address_id": {}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/orders?select=*,users!address_id(*)", url) - }) - - t.Run("with all columns", func(t *testing.T) { + t.Run("match url query for multiple relations", func(t *testing.T) { + t.Run("without where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). - Model(OrdersMockModel{}). - With( - "UsersMockModel", - map[string][]string{ - "users!address_id": {"*"}, - }, - ). + Model(orderMockModel). + Preload("UserBilling"). + Preload("UserAddress"). GetUrl() - assert.Equal(t, "/rest/v1/orders?select=*,users!address_id(*)", url) + assert.Equal(t, "/rest/v1/orders?select=*,user_billing:users!billing_id(*),user_address:users!address_id(*)", url) }) - t.Run("with selected id", func(t *testing.T) { + t.Run("with where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). - Model(OrdersMockModel{}). - With( - "UsersMockModel", - map[string][]string{ - "users!address_id": {"id", "username"}, - }, - ). + Model(orderMockModel). + Preload("UserBilling", "status", "eq", "approved"). + Preload("UserAddress", "status", "eq", "approved"). GetUrl() - assert.Equal(t, "/rest/v1/orders?select=*,users!address_id(id,username)", url) + assert.Equal(t, "/rest/v1/orders?select=*,user_billing:users!billing_id(*),user_address:users!address_id(*)&users.status=eq.approved&users.status=eq.approved", url) }) }) } diff --git a/pkg/db/utils.go b/pkg/db/utils.go index 6e539d47..26c5404d 100644 --- a/pkg/db/utils.go +++ b/pkg/db/utils.go @@ -3,54 +3,12 @@ package db import ( "log" "os" - "reflect" "sort" "strings" "github.com/sev-2/raiden" - "github.com/sev-2/raiden/pkg/resource" ) -func findModel(targetName string) interface{} { - for _, m := range resource.RegisteredModels { - t := reflect.TypeOf(m) - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - if t.Name() == targetName { - return m - } - } - - return nil -} - -func findModelByTable(table string) string { - for _, m := range resource.RegisteredModels { - - t := reflect.TypeOf(m) - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - if field.Name == "Metadata" { - tableName := field.Tag.Get("tableName") - if tableName == table { - parts := strings.Split(t.String(), ".") - return parts[len(parts)-1] - } - } - } - } - - return "" -} - func getConfig() *raiden.Config { currentDir, err := os.Getwd() if err != nil { From 14dc540aa14492e55aea0dacd344bae5974dc072 Mon Sep 17 00:00:00 2001 From: Taufan Adhitya Date: Wed, 2 Oct 2024 14:21:49 +0700 Subject: [PATCH 2/5] fix: linter --- pkg/db/relation.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/db/relation.go b/pkg/db/relation.go index f58e0b32..bf158e1e 100644 --- a/pkg/db/relation.go +++ b/pkg/db/relation.go @@ -120,7 +120,7 @@ func (q *Query) Preload(table string, args ...string) *Query { selects = append(selects, fmt.Sprintf("%s(%s)", related, "*")) } - if (relatedFieldPrefix == "") { + if relatedFieldPrefix == "" { relatedFieldPrefix = table } else { relatedFieldPrefix = fmt.Sprintf("%s.%s", relatedFieldPrefix, table) From 807b1814d9e8b1d60ef70f117e42c18739ba798b Mon Sep 17 00:00:00 2001 From: Taufan Adhitya Date: Wed, 2 Oct 2024 14:38:10 +0700 Subject: [PATCH 3/5] fix: linter --- pkg/db/select.go | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/pkg/db/select.go b/pkg/db/select.go index 77f72e0b..325aba3f 100644 --- a/pkg/db/select.go +++ b/pkg/db/select.go @@ -87,33 +87,4 @@ func isValidColumnName(column string) bool { isAllowed, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]{1,59}`, column) return isAllowed -} - -func isForeignKeyExist(m interface{}, column string) bool { - if column == "inner" { - return true - } - - t := reflect.TypeOf(m) - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - - if tagValue := field.Tag.Get("join"); tagValue != "" { - for _, part := range strings.Split(tagValue, ";") { - kv := strings.SplitN(part, ":", 2) - if len(kv) == 2 && kv[0] == "targetForeign" { - if kv[1] == column { - return true - } - } - } - } - } - - return false -} +} \ No newline at end of file From cbb1b3376d9af82a2fd8a68412ede219185e0877 Mon Sep 17 00:00:00 2001 From: Taufan Adhitya Date: Wed, 2 Oct 2024 14:39:34 +0700 Subject: [PATCH 4/5] fix: linter --- pkg/db/select.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/db/select.go b/pkg/db/select.go index 325aba3f..edb255ab 100644 --- a/pkg/db/select.go +++ b/pkg/db/select.go @@ -87,4 +87,4 @@ func isValidColumnName(column string) bool { isAllowed, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]{1,59}`, column) return isAllowed -} \ No newline at end of file +} From b9574df49f49777a2807f2fd3f7200bc858564fb Mon Sep 17 00:00:00 2001 From: Taufan Adhitya Date: Wed, 2 Oct 2024 15:27:22 +0700 Subject: [PATCH 5/5] fix: clean up --- pkg/db/relation.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pkg/db/relation.go b/pkg/db/relation.go index bf158e1e..04666b23 100644 --- a/pkg/db/relation.go +++ b/pkg/db/relation.go @@ -31,8 +31,6 @@ func (q *Query) Preload(table string, args ...string) *Query { relations := strings.Split(table, ".") - fmt.Printf("Preloading table: %s, field: %s, operator: %s, value: %s\n", table, field, operator, value) - if len(relations) > 3 { raiden.Fatal("unsupported nested relations more than 3 levels") } @@ -53,7 +51,6 @@ func (q *Query) Preload(table string, args ...string) *Query { raiden.Fatal("could not find related model.") } - fmt.Printf("Related model: %v\n", relatedModel) relatedModelStruct := reflect.TypeOf(relatedModel) if relatedModelStruct.Kind() == reflect.Ptr { relatedModelStruct = relatedModelStruct.Elem() @@ -127,9 +124,6 @@ func (q *Query) Preload(table string, args ...string) *Query { } } - fmt.Println("Relations: ", relationMap) - fmt.Println("Selects: ", selects) - fmt.Println("Prefix: ", relatedFieldPrefix) q.Relations = append(q.Relations, selects...) if field != "" && operator != "" && value != "" {