diff --git a/middleware/body_decoder.go b/middleware/body_decoder.go index 5577cc0..cec49e9 100644 --- a/middleware/body_decoder.go +++ b/middleware/body_decoder.go @@ -118,56 +118,9 @@ func (m BodyDecoder) Wrap(h handler.Handler) handler.Handler { r.Body = io.NopCloser(bytes.NewReader(buffer.Bytes())) - valOf := reflect.ValueOf(m.BodyPtr).Elem() - if kind := valOf.Kind(); kind == reflect.Struct { - val := reflect.Indirect(valOf) - - for i := 0; i < val.NumField(); i++ { - typeOfParameters := val.Type() - typeOfFieldI := typeOfParameters.Field(i) - - if pattern := typeOfFieldI.Tag.Get("pattern"); pattern != "" { - rex, errC := regexp.Compile(pattern) - if errC != nil { - return nil, errors.InternalServerError("pattern_must_be_regex", "pattern must contain a regular expression") - } - - if !val.Field(i).CanInterface() { - return nil, errors.InternalServerError("interface_failed", "interface cannot be used without panicking") - } - - fieldValue := val.Field(i).Interface() - if !rex.MatchString(fmt.Sprintf("%v", fieldValue)) { - return nil, errors.BadRequest("body_validation_failed", "field %s does not match the required pattern", typeOfFieldI.Name) - } - } - - if typeOfFieldI.Tag.Get("required") != "true" { - continue - } - - if !val.Field(i).IsZero() { - continue - } - - fieldName := typeOfFieldI.Name - switch jsonTag := typeOfFieldI.Tag.Get("json"); jsonTag { - case "-": - return nil, errors.InternalServerError("invalid_config", "field '%s' is required but json tag value is '-'", fieldName) - - case "": - return nil, errors.BadRequest("missing_param", "field %s is required", fieldName) - - default: - parts := strings.Split(jsonTag, ",") - name := parts[0] - if name == "" { - name = fieldName - } - - return nil, errors.BadRequest("missing_param", "field %s is required", name) - } - } + err = validateRecursive(m.BodyPtr) + if err != nil { + return nil, err } if v, ok := m.BodyPtr.(BodyValidation); !m.SkipValidation && ok { @@ -229,3 +182,94 @@ func (m BodyDecoder) resolveContentType(r *http.Request) (Decoder, error) { return result, nil } + +func validateStruct(val reflect.Value) errors.Error { + for i := 0; i < val.NumField(); i++ { + typeOfParameters := val.Type() + typeOfFieldI := typeOfParameters.Field(i) + + if !val.Field(i).CanInterface() { + return errors.InternalServerError("can_interface_failed", "interface for field %s cannot be used without panicking", typeOfFieldI.Name) + } + + // pattern validation + if pattern := typeOfFieldI.Tag.Get("pattern"); pattern != "" { + rex, errC := regexp.Compile(pattern) + if errC != nil { + return errors.InternalServerError("pattern_must_be_regex", "pattern must contain a regular expression") + } + + fieldValue := val.Field(i).Interface() + if !rex.MatchString(fmt.Sprintf("%v", fieldValue)) { + return errors.BadRequest("body_validation_failed", "field %s does not match the required pattern", typeOfFieldI.Name) + } + } + + // Enum validation + enumValues := typeOfFieldI.Tag.Get("enum") + if enumValues != "" { + enumList := strings.Split(enumValues, ",") + fieldValue := fmt.Sprintf("%v", val.Field(i).Interface()) + enumValid := false + for _, enum := range enumList { + if fieldValue == enum { + enumValid = true + break + } + } + if !enumValid { + return errors.BadRequest("enum_validation_failed", "field %s must be one of [%s]", typeOfFieldI.Name, enumValues) + } + } + + // required field validation + if typeOfFieldI.Tag.Get("required") != "true" { + continue + } + + if !val.Field(i).IsZero() { + continue + } + + fieldName := typeOfFieldI.Name + switch jsonTag := typeOfFieldI.Tag.Get("json"); jsonTag { + case "-": + return errors.InternalServerError("invalid_config", "field '%s' is required but json tag value is '-'", fieldName) + + case "": + return errors.BadRequest("missing_param", "field %s is required", fieldName) + + default: + parts := strings.Split(jsonTag, ",") + name := parts[0] + if name == "" { + name = fieldName + } + + return errors.BadRequest("missing_param", "field %s is required", name) + } + } + return nil +} + +func validateRecursive(m interface{}) errors.Error { + valOf := reflect.ValueOf(m).Elem() + + if kind := valOf.Kind(); kind == reflect.Struct { + val := reflect.Indirect(valOf) + if err := validateStruct(val); err != nil { + return err + } + + // Recursively validate nested structs + for i := 0; i < val.NumField(); i++ { + if val.Field(i).Kind() == reflect.Struct { + if err := validateRecursive(val.Field(i).Addr().Interface()); err != nil { + return err + } + } + } + } + + return nil +}