Skip to content

Commit

Permalink
Fix dive only (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
deankarn authored Dec 24, 2019
1 parent cad431a commit d9c7865
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 64 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package mold
============
![Project status](https://img.shields.io/badge/version-3.0.0-green.svg)
![Project status](https://img.shields.io/badge/version-3.0.1-green.svg)
[![Build Status](https://travis-ci.org/go-playground/mold.svg?branch=v2)](https://travis-ci.org/go-playground/mold)
[![Coverage Status](https://coveralls.io/repos/github/go-playground/mold/badge.svg?branch=v2)](https://coveralls.io/github/go-playground/mold?branch=v2)
[![Go Report Card](https://goreportcard.com/badge/github.com/go-playground/mold)](https://goreportcard.com/report/github.com/go-playground/mold)
Expand Down
2 changes: 1 addition & 1 deletion _examples/full/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/go-playground/mold/v3/modifiers"
"github.com/go-playground/mold/v3/scrubbers"

"gopkg.in/go-playground/validator.v9"
"github.com/go-playground/validator/v10"
)

// This example is centered around a form post, but doesn't have to be
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
github.com/go-playground/assert v1.2.1 h1:ad06XqC+TOv0nJWnbULSlh3ehp5uLuQEojZY5Tq8RgI=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/segmentio/go-snakecase v1.2.0 h1:4cTmEjPGi03WmyAHWBjX53viTpBkn/z+4DO++fqYvpw=
Expand Down
119 changes: 58 additions & 61 deletions mold.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

var (
timeType = reflect.TypeOf(time.Time{})
defaultCField = &cField{}
restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
restrictedTagErr = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
)
Expand Down Expand Up @@ -76,7 +75,6 @@ func (t *Transformer) Register(tag string, fn Func) {
if ok || strings.ContainsAny(tag, restrictedTagChars) {
panic(fmt.Sprintf(restrictedTagErr, tag))
}

t.transformations[tag] = fn
}

Expand Down Expand Up @@ -132,11 +130,10 @@ func (t *Transformer) Struct(ctx context.Context, v interface{}) error {
if val.Kind() != reflect.Struct || val.Type() == timeType {
return &ErrInvalidTransformation{typ: reflect.TypeOf(v)}
}

return t.setByStruct(ctx, val, typ, nil)
return t.setByStruct(ctx, val, typ)
}

func (t *Transformer) setByStruct(ctx context.Context, current reflect.Value, typ reflect.Type, ct *cTag) (err error) {
func (t *Transformer) setByStruct(ctx context.Context, current reflect.Value, typ reflect.Type) (err error) {
cs, ok := t.cCache.Get(typ)
if !ok {
if cs, err = t.extractStructCache(current); err != nil {
Expand All @@ -155,7 +152,7 @@ func (t *Transformer) setByStruct(ctx context.Context, current reflect.Value, ty

for i := 0; i < len(cs.fields); i++ {
f = cs.fields[i]
if err = t.setByField(ctx, current.Field(f.idx), f, f.cTags); err != nil {
if err = t.setByField(ctx, current.Field(f.idx), f.cTags); err != nil {
return
}
}
Expand Down Expand Up @@ -192,19 +189,15 @@ func (t *Transformer) Field(ctx context.Context, v interface{}, tags string) (er
}
t.tCache.lock.Unlock()
}
err = t.setByField(ctx, val, defaultCField, ctag)
err = t.setByField(ctx, val, ctag)
return
}

func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, cf *cField, ct *cTag) (err error) {
func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, ct *cTag) (err error) {
current, kind := extractType(orig)

if ct.hasTag {
for {
if ct == nil {
break
}

if ct != nil && ct.hasTag {
for ct != nil {
switch ct.typeof {
case typeEndKeys:
return
Expand All @@ -213,52 +206,9 @@ func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, cf *cF

switch kind {
case reflect.Slice, reflect.Array:
reusableCF := &cField{}

for i := 0; i < current.Len(); i++ {
if err = t.setByField(ctx, current.Index(i), reusableCF, ct); err != nil {
return
}
}

err = t.setByIterable(ctx, current, ct)
case reflect.Map:
reusableCF := &cField{}

hasKeys := ct != nil && ct.typeof == typeKeys && ct.keys != nil

for _, key := range current.MapKeys() {
newVal := reflect.New(current.Type().Elem()).Elem()
newVal.Set(current.MapIndex(key))

if hasKeys {

// remove current map key as we may be changing it
// and re-add to the map afterwards
current.SetMapIndex(key, reflect.Value{})

newKey := reflect.New(current.Type().Key()).Elem()
newKey.Set(key)
key = newKey

// handle map key
if err = t.setByField(ctx, key, reusableCF, ct.keys); err != nil {
return
}

// can be nil when just keys being validated
if ct.next != nil {
if err = t.setByField(ctx, newVal, reusableCF, ct.next); err != nil {
return
}
}
} else {
if err = t.setByField(ctx, newVal, reusableCF, ct); err != nil {
return
}
}
current.SetMapIndex(key, newVal)
}

err = t.setByMap(ctx, current, ct)
default:
err = ErrInvalidDive
}
Expand Down Expand Up @@ -300,13 +250,60 @@ func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, cf *cF
newVal := reflect.New(typ).Elem()
newVal.Set(current)

if err = t.setByStruct(ctx, newVal, typ, ct); err != nil {
if err = t.setByStruct(ctx, newVal, typ); err != nil {
return
}
orig.Set(newVal)
return
}
err = t.setByStruct(ctx, current, typ, ct)
err = t.setByStruct(ctx, current, typ)
}
return
}

func (t *Transformer) setByIterable(ctx context.Context, current reflect.Value, ct *cTag) (err error) {
for i := 0; i < current.Len(); i++ {
if err = t.setByField(ctx, current.Index(i), ct); err != nil {
return
}
}
return
}

func (t *Transformer) setByMap(ctx context.Context, current reflect.Value, ct *cTag) error {
hasKeys := ct != nil && ct.typeof == typeKeys && ct.keys != nil

for _, key := range current.MapKeys() {
newVal := reflect.New(current.Type().Elem()).Elem()
newVal.Set(current.MapIndex(key))

if hasKeys {
// remove current map key as we may be changing it
// and re-add to the map afterwards
current.SetMapIndex(key, reflect.Value{})

newKey := reflect.New(current.Type().Key()).Elem()
newKey.Set(key)
key = newKey

// handle map key
if err := t.setByField(ctx, key, ct.keys); err != nil {
return err
}

// can be nil when just keys being validated
if ct.next != nil {
if err := t.setByField(ctx, newVal, ct.next); err != nil {
return err
}
}
} else {
if err := t.setByField(ctx, newVal, ct); err != nil {
return err
}
}
current.SetMapIndex(key, newVal)
}

return nil
}
78 changes: 78 additions & 0 deletions mold_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,81 @@ func TestDiveKeys(t *testing.T) {
err = set.Field(context.Background(), &m, "dive,keys,default,endkeys,err")
NotEqual(t, err, nil)
}

func TestStructArray(t *testing.T) {
type InnerStruct struct {
String string `s:"defaultStr"`
}

type Test struct {
Inner InnerStruct
Arr []InnerStruct `s:"defaultArr"`
ArrDive []InnerStruct `s:"defaultArr,dive"`
ArrNoTag []InnerStruct
}

set := New()
set.SetTagName("s")
set.Register("defaultArr", func(ctx context.Context, t *Transformer, value reflect.Value, param string) error {
if HasValue(value) {
return nil
}
value.Set(reflect.MakeSlice(value.Type(), 2, 2))
return nil
})
set.Register("defaultStr", func(ctx context.Context, t *Transformer, value reflect.Value, param string) error {
if value.String() == "ok" {
return errors.New("ALREADY OK")
}
value.SetString("default")
return nil
})

var tt Test

err := set.Struct(context.Background(), &tt)
Equal(t, err, nil)
Equal(t, len(tt.Arr), 2)
Equal(t, len(tt.ArrDive), 2)
Equal(t, tt.Arr[0].String, "")
Equal(t, tt.Arr[1].String, "")
Equal(t, tt.ArrDive[0].String, "default")
Equal(t, tt.ArrDive[1].String, "default")

Equal(t, tt.Inner.String, "default")

tt2 := Test{
Arr: make([]InnerStruct, 1),
}

err = set.Struct(context.Background(), &tt2)
Equal(t, err, nil)
Equal(t, len(tt2.Arr), 1)
Equal(t, tt2.Arr[0].String, "")

tt3 := Test{
Arr: []InnerStruct{{"ok"}},
}

err = set.Struct(context.Background(), &tt3)
Equal(t, err, nil)
Equal(t, len(tt3.Arr), 1)
Equal(t, tt3.Arr[0].String, "ok")

tt4 := Test{
ArrDive: []InnerStruct{{"ok"}},
}

err = set.Struct(context.Background(), &tt4)
NotEqual(t, err, nil)
Equal(t, err.Error(), "ALREADY OK")

tt5 := Test{
ArrNoTag: make([]InnerStruct, 1),
}

err = set.Struct(context.Background(), &tt5)
Equal(t, err, nil)
Equal(t, len(tt5.ArrNoTag), 1)
Equal(t, tt5.ArrNoTag[0].String, "")
}

0 comments on commit d9c7865

Please sign in to comment.