diff --git a/pkg/db/relation.go b/pkg/db/relation.go index a9367204..04666b23 100644 --- a/pkg/db/relation.go +++ b/pkg/db/relation.go @@ -2,89 +2,197 @@ package db import ( "fmt" + "reflect" "strings" "github.com/sev-2/raiden" ) -func (q *Query) With(r string, columns map[string][]string) *Query { +func (q *Query) Preload(table string, args ...string) *Query { - relations := strings.Split(r, ".") + relatedFieldPrefix := "" + relationMap := make(map[string]map[string]string) + + field := "" + operator := "" + value := "" + + // Override with supplied arguments if available + if len(args) > 0 && args[0] != "" { + field = args[0] + } + + if len(args) > 1 && args[1] != "" { + operator = args[1] + } + if len(args) > 2 && args[2] != "" { + value = args[2] + } + + relations := strings.Split(table, ".") if len(relations) > 3 { raiden.Fatal("unsupported nested relations more than 3 levels") } - for _, m := range relations { - if findModel(m) == nil { - raiden.Fatal(fmt.Sprintf("invalid model name: %s", m)) + for i, relation := range relations { + var currentModelStruct reflect.Type + var relatedModel interface{} + var err error + if i == 0 { + currentModelStruct = reflect.TypeOf(q.model) + relatedModel, err = instantiateFieldByPath(q.model, relation) + } else { + currentModelStruct = reflect.TypeOf(relatedModel) + relatedModel, err = instantiateFieldByPath(relatedModel, relation) } - } - - var selects []string - for _, r := range reverseSortString(relations) { - model := findModel(r) - table := GetTable(model) - - for k := range columns { - if strings.Contains(k, "!") { - split := strings.Split(k, "!") - m := findModelByTable(split[0]) - c := findModel(m) - if !isForeignKeyExist(c, split[1]) { - err := fmt.Sprintf("invalid foreign key: \"%s\" key is not exist.", split[1]) - raiden.Fatal(err) - } else { - table = fmt.Sprintf("%s!%s", table, split[1]) - } - } + if err != nil { + raiden.Fatal("could not find related model.") } - // Columns validations - for _, c := range columns[table] { - var column = c + relatedModelStruct := reflect.TypeOf(relatedModel) + if relatedModelStruct.Kind() == reflect.Ptr { + relatedModelStruct = relatedModelStruct.Elem() + } - if column == "*" { - continue + var relatedAlias string + var relatedTableName string + var relatedForeignKey string + for i := 0; i < relatedModelStruct.NumField(); i++ { + field := relatedModelStruct.Field(i) + if field.Name == "Metadata" { + relatedTableName = field.Tag.Get("tableName") } + } - if strings.Contains(c, ":") { - split := strings.Split(c, ":") - alias := split[0] - column = split[1] - if !isValidColumnName(alias) { - err := fmt.Sprintf("invalid alias column name: \"%s\" name is invalid.", alias) - raiden.Fatal(err) - } - } + if currentModelStruct.Kind() == reflect.Ptr { + currentModelStruct = currentModelStruct.Elem() + } - if !isColumnExist(model, column) { - err := fmt.Sprintf("invalid column: \"%s\" is not available on \"%s\" table.", column, table) - raiden.Fatal(err) - } + for i := 0; i < currentModelStruct.NumField(); i++ { + field := currentModelStruct.Field(i) + if field.Name == relation { + jsonField := field.Tag.Get("json") + join := field.Tag.Get("join") + + relatedAlias = strings.Split(jsonField, ",")[0] + relatedForeignKey, err = getTagValue(join, "foreignKey") - if !isValidColumnName(column) { - err := fmt.Sprintf("invalid column: \"%s\" name is invalid.", column) - raiden.Fatal(err) + if err != nil { + raiden.Fatal("could not find foreign key in join tag.") + } } } - cols := strings.Join(columns[table], ",") + relationData := make(map[string]string) + relationData["alias"] = relatedAlias + relationData["table"] = relatedTableName + relationData["fk"] = relatedForeignKey + relationMap[relation] = relationData + } + + var selects []string - if len(cols) == 0 { - cols = "*" + // After we have the relation map, we can construct the select query + // If the table is `Users.Team.Organization`, + // the select query will be `users(teams(organizations(*)))` + for _, r := range reverseSortString(relations) { + d := relationMap[r] + alias := d["alias"] + table := d["table"] + fk := d["fk"] + + var related string + if alias == table { + related = fmt.Sprintf("%s!%s", table, fk) + } else { + related = fmt.Sprintf("%s:%s!%s", alias, table, fk) } if len(selects) > 0 { lastQuery := selects[len(selects)-1] - selects[len(selects)-1] = fmt.Sprintf("%s(%s,%s)", table, cols, lastQuery) + selects[len(selects)-1] = fmt.Sprintf("%s(%s,%s)", related, "*", lastQuery) + } else { + selects = append(selects, fmt.Sprintf("%s(%s)", related, "*")) + } + + if relatedFieldPrefix == "" { + relatedFieldPrefix = table } else { - selects = append(selects, fmt.Sprintf("%s(%s)", table, cols)) + relatedFieldPrefix = fmt.Sprintf("%s.%s", relatedFieldPrefix, table) } } q.Relations = append(q.Relations, selects...) + if field != "" && operator != "" && value != "" { + if q.WhereAndList == nil { + q.WhereAndList = &[]string{} + } + + *q.WhereAndList = append( + *q.WhereAndList, + fmt.Sprintf("%s=%s.%s", fmt.Sprintf("%s.%s", relatedFieldPrefix, field), operator, getStringValue(value)), + ) + } + return q } + +func instantiateFieldByPath(model interface{}, fieldPath string) (interface{}, error) { + fields := strings.Split(fieldPath, ".") + val := reflect.ValueOf(model) + + // If it's a pointer, dereference it, but keep track of the original value to modify it + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected a struct, got %s", val.Kind()) + } + + // Traverse the struct fields based on the field path + for _, fieldName := range fields { + fieldVal := val.FieldByName(fieldName) + + if !fieldVal.IsValid() { + return nil, fmt.Errorf("field %s not found", fieldName) + } + + if fieldVal.Kind() == reflect.Ptr { + if fieldVal.IsNil() { + fieldType := fieldVal.Type() + + fieldVal = reflect.New(fieldType.Elem()) + } + fieldVal = fieldVal.Elem() + } + + if fieldVal.Kind() != reflect.Struct { + return nil, fmt.Errorf("field %s is not a struct", fieldName) + } + + val = fieldVal + } + + newInstance := reflect.New(val.Type()).Elem().Interface() + return newInstance, nil +} + +func getTagValue(tag, key string) (string, error) { + // Split the tag by semicolon to get individual key-value pairs + pairs := strings.Split(tag, ";") + + // Iterate through the pairs to find the key + for _, pair := range pairs { + // Split the pair by colon to get key and value + kv := strings.Split(pair, ":") + if len(kv) == 2 && kv[0] == key { + return kv[1], nil // Return the value if key matches + } + } + + return "", fmt.Errorf("key %s not found in tag", key) +} diff --git a/pkg/db/relation_test.go b/pkg/db/relation_test.go index 3fc67c9f..9e02fcde 100644 --- a/pkg/db/relation_test.go +++ b/pkg/db/relation_test.go @@ -3,183 +3,52 @@ package db import ( "testing" - "github.com/sev-2/raiden/pkg/resource" "github.com/stretchr/testify/assert" ) func TestWith(t *testing.T) { - resource.RegisterModels( - ArticleMockModel{}, - UsersMockModel{}, - TeamsMockModel{}, - OrganizationsMockModel{}, - ) + articleMockModel := ArticleMockModel{} + orderMockModel := OrdersMockModel{} t.Run("match url query for single relation", func(t *testing.T) { - t.Run("without selected columns", func(t *testing.T) { + t.Run("without where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). Model(articleMockModel). - With("UsersMockModel", nil). + Preload("User"). GetUrl() - assert.Equal(t, "/rest/v1/articles?select=*,users(*)", url) + assert.Equal(t, "/rest/v1/articles?select=*,user:users!user_id(*)", url) }) - t.Run("with selected columns", func(t *testing.T) { + t.Run("with where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). Model(articleMockModel). - With( - "UsersMockModel", - map[string][]string{ - "users": {"id", "username"}, - }, - ). + Preload("User", "status", "eq", "approved"). GetUrl() - assert.Equal(t, "/rest/v1/articles?select=*,users(id,username)", url) - }) - - t.Run("with selected columns and aliases", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel", - map[string][]string{ - "users": {"id", "userid:username"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,userid:username)", url) - }) - }) - - t.Run("match url query for two-nested relation", func(t *testing.T) { - t.Run("without selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With("UsersMockModel.TeamsMockModel", nil). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(*,teams(*))", url) - }) - - t.Run("with selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel", - map[string][]string{ - "users": {"id", "username"}, - "teams": {"id", "name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,username,teams(id,name))", url) - }) - - t.Run("with selected columns and aliases", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel", - map[string][]string{ - "users": {"id", "userid:username"}, - "teams": {"id", "team_name:name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,userid:username,teams(id,team_name:name))", url) + assert.Equal(t, "/rest/v1/articles?select=*,user:users!user_id(*)&users.status=eq.approved", url) }) }) - t.Run("match url query for three-nested relation", func(t *testing.T) { - t.Run("without selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With("UsersMockModel.TeamsMockModel.OrganizationsMockModel", nil). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(*,teams(*,organizations(*)))", url) - }) - - t.Run("with selected columns", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel.OrganizationsMockModel", - map[string][]string{ - "users": {"id", "username"}, - "teams": {"id", "name"}, - "organizations": {"id", "name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,username,teams(id,name,organizations(id,name)))", url) - }) - - t.Run("with selected columns and aliases", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(articleMockModel). - With( - "UsersMockModel.TeamsMockModel.OrganizationsMockModel", - map[string][]string{ - "users": {"id", "userid:username"}, - "teams": {"id", "team_name:name"}, - "organizations": {"id", "org_name:name"}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/articles?select=*,users(id,userid:username,teams(id,team_name:name,organizations(id,org_name:name)))", url) - }) - }) - - t.Run("match url query with foreign key", func(t *testing.T) { - - t.Run("without selected column", func(t *testing.T) { - url := NewQuery(&mockRaidenContext). - Model(OrdersMockModel{}). - With( - "UsersMockModel", - map[string][]string{ - "users!address_id": {}, - }, - ). - GetUrl() - - assert.Equal(t, "/rest/v1/orders?select=*,users!address_id(*)", url) - }) - - t.Run("with all columns", func(t *testing.T) { + t.Run("match url query for multiple relations", func(t *testing.T) { + t.Run("without where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). - Model(OrdersMockModel{}). - With( - "UsersMockModel", - map[string][]string{ - "users!address_id": {"*"}, - }, - ). + Model(orderMockModel). + Preload("UserBilling"). + Preload("UserAddress"). GetUrl() - assert.Equal(t, "/rest/v1/orders?select=*,users!address_id(*)", url) + assert.Equal(t, "/rest/v1/orders?select=*,user_billing:users!billing_id(*),user_address:users!address_id(*)", url) }) - t.Run("with selected id", func(t *testing.T) { + t.Run("with where condition", func(t *testing.T) { url := NewQuery(&mockRaidenContext). - Model(OrdersMockModel{}). - With( - "UsersMockModel", - map[string][]string{ - "users!address_id": {"id", "username"}, - }, - ). + Model(orderMockModel). + Preload("UserBilling", "status", "eq", "approved"). + Preload("UserAddress", "status", "eq", "approved"). GetUrl() - assert.Equal(t, "/rest/v1/orders?select=*,users!address_id(id,username)", url) + assert.Equal(t, "/rest/v1/orders?select=*,user_billing:users!billing_id(*),user_address:users!address_id(*)&users.status=eq.approved&users.status=eq.approved", url) }) }) } diff --git a/pkg/db/select.go b/pkg/db/select.go index 77f72e0b..edb255ab 100644 --- a/pkg/db/select.go +++ b/pkg/db/select.go @@ -88,32 +88,3 @@ func isValidColumnName(column string) bool { return isAllowed } - -func isForeignKeyExist(m interface{}, column string) bool { - if column == "inner" { - return true - } - - t := reflect.TypeOf(m) - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - - if tagValue := field.Tag.Get("join"); tagValue != "" { - for _, part := range strings.Split(tagValue, ";") { - kv := strings.SplitN(part, ":", 2) - if len(kv) == 2 && kv[0] == "targetForeign" { - if kv[1] == column { - return true - } - } - } - } - } - - return false -} diff --git a/pkg/db/utils.go b/pkg/db/utils.go index 6e539d47..26c5404d 100644 --- a/pkg/db/utils.go +++ b/pkg/db/utils.go @@ -3,54 +3,12 @@ package db import ( "log" "os" - "reflect" "sort" "strings" "github.com/sev-2/raiden" - "github.com/sev-2/raiden/pkg/resource" ) -func findModel(targetName string) interface{} { - for _, m := range resource.RegisteredModels { - t := reflect.TypeOf(m) - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - if t.Name() == targetName { - return m - } - } - - return nil -} - -func findModelByTable(table string) string { - for _, m := range resource.RegisteredModels { - - t := reflect.TypeOf(m) - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - if field.Name == "Metadata" { - tableName := field.Tag.Get("tableName") - if tableName == table { - parts := strings.Split(t.String(), ".") - return parts[len(parts)-1] - } - } - } - } - - return "" -} - func getConfig() *raiden.Config { currentDir, err := os.Getwd() if err != nil {