Skip to content

Commit

Permalink
[gapi] BodyDecoder validate nested struct
Browse files Browse the repository at this point in the history
  * validate enum
  • Loading branch information
rach-ba committed Sep 28, 2023
1 parent 2cb96ee commit dcb38ba
Showing 1 changed file with 94 additions and 50 deletions.
144 changes: 94 additions & 50 deletions middleware/body_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,56 +118,9 @@ func (m BodyDecoder) Wrap(h handler.Handler) handler.Handler {

r.Body = io.NopCloser(bytes.NewReader(buffer.Bytes()))

valOf := reflect.ValueOf(m.BodyPtr).Elem()
if kind := valOf.Kind(); kind == reflect.Struct {
val := reflect.Indirect(valOf)

for i := 0; i < val.NumField(); i++ {
typeOfParameters := val.Type()
typeOfFieldI := typeOfParameters.Field(i)

if pattern := typeOfFieldI.Tag.Get("pattern"); pattern != "" {
rex, errC := regexp.Compile(pattern)
if errC != nil {
return nil, errors.InternalServerError("pattern_must_be_regex", "pattern must contain a regular expression")
}

if !val.Field(i).CanInterface() {
return nil, errors.InternalServerError("interface_failed", "interface cannot be used without panicking")
}

fieldValue := val.Field(i).Interface()
if !rex.MatchString(fmt.Sprintf("%v", fieldValue)) {
return nil, errors.BadRequest("body_validation_failed", "field %s does not match the required pattern", typeOfFieldI.Name)
}
}

if typeOfFieldI.Tag.Get("required") != "true" {
continue
}

if !val.Field(i).IsZero() {
continue
}

fieldName := typeOfFieldI.Name
switch jsonTag := typeOfFieldI.Tag.Get("json"); jsonTag {
case "-":
return nil, errors.InternalServerError("invalid_config", "field '%s' is required but json tag value is '-'", fieldName)

case "":
return nil, errors.BadRequest("missing_param", "field %s is required", fieldName)

default:
parts := strings.Split(jsonTag, ",")
name := parts[0]
if name == "" {
name = fieldName
}

return nil, errors.BadRequest("missing_param", "field %s is required", name)
}
}
err = validateRecursive(m.BodyPtr)
if err != nil {
return nil, err
}

if v, ok := m.BodyPtr.(BodyValidation); !m.SkipValidation && ok {
Expand Down Expand Up @@ -229,3 +182,94 @@ func (m BodyDecoder) resolveContentType(r *http.Request) (Decoder, error) {

return result, nil
}

func validateStruct(val reflect.Value) errors.Error {
for i := 0; i < val.NumField(); i++ {
typeOfParameters := val.Type()
typeOfFieldI := typeOfParameters.Field(i)

if !val.Field(i).CanInterface() {
return errors.InternalServerError("can_interface_failed", "interface for field %s cannot be used without panicking", typeOfFieldI.Name)
}

// pattern validation
if pattern := typeOfFieldI.Tag.Get("pattern"); pattern != "" {
rex, errC := regexp.Compile(pattern)
if errC != nil {
return errors.InternalServerError("pattern_must_be_regex", "pattern must contain a regular expression")
}

fieldValue := val.Field(i).Interface()
if !rex.MatchString(fmt.Sprintf("%v", fieldValue)) {
return errors.BadRequest("body_validation_failed", "field %s does not match the required pattern", typeOfFieldI.Name)
}
}

// Enum validation
enumValues := typeOfFieldI.Tag.Get("enum")
if enumValues != "" {
enumList := strings.Split(enumValues, ",")
fieldValue := fmt.Sprintf("%v", val.Field(i).Interface())
enumValid := false
for _, enum := range enumList {
if fieldValue == enum {
enumValid = true
break
}
}
if !enumValid {
return errors.BadRequest("enum_validation_failed", "field %s must be one of [%s]", typeOfFieldI.Name, enumValues)
}
}

// required field validation
if typeOfFieldI.Tag.Get("required") != "true" {
continue
}

if !val.Field(i).IsZero() {
continue
}

fieldName := typeOfFieldI.Name
switch jsonTag := typeOfFieldI.Tag.Get("json"); jsonTag {
case "-":
return errors.InternalServerError("invalid_config", "field '%s' is required but json tag value is '-'", fieldName)

case "":
return errors.BadRequest("missing_param", "field %s is required", fieldName)

default:
parts := strings.Split(jsonTag, ",")
name := parts[0]
if name == "" {
name = fieldName
}

return errors.BadRequest("missing_param", "field %s is required", name)
}
}
return nil
}

func validateRecursive(m interface{}) errors.Error {
valOf := reflect.ValueOf(m).Elem()

if kind := valOf.Kind(); kind == reflect.Struct {
val := reflect.Indirect(valOf)
if err := validateStruct(val); err != nil {
return err
}

// Recursively validate nested structs
for i := 0; i < val.NumField(); i++ {
if val.Field(i).Kind() == reflect.Struct {
if err := validateRecursive(val.Field(i).Addr().Interface()); err != nil {
return err
}
}
}
}

return nil
}

0 comments on commit dcb38ba

Please sign in to comment.