diff --git a/desc/export_test.go b/desc/export_test.go index a33d1984..2d792b2c 100644 --- a/desc/export_test.go +++ b/desc/export_test.go @@ -7,6 +7,6 @@ var ( NewStub = newStub ) -func (d *Stub) AddDTO(mTyp reflect.Type) error { - return d.addDTO(mTyp) +func (d *Stub) AddDTO(mTyp reflect.Type, isErr bool) error { + return d.addDTO(mTyp, isErr) } diff --git a/desc/service.go b/desc/service.go index 9bfa3346..194828e5 100644 --- a/desc/service.go +++ b/desc/service.go @@ -166,26 +166,26 @@ func (s Service) Stub(pkgName string, tags ...string) (*Stub, error) { func (s Service) dtoStub(stub *Stub) error { for _, c := range s.Contracts { - err := stub.addDTO(reflect.TypeOf(c.Input)) + err := stub.addDTO(reflect.TypeOf(c.Input), false) if err != nil { return err } if c.Output != nil { - err = stub.addDTO(reflect.TypeOf(c.Output)) + err = stub.addDTO(reflect.TypeOf(c.Output), false) if err != nil { return err } } for _, pe := range s.PossibleErrors { - err = stub.addDTO(reflect.TypeOf(pe.Message)) + err = stub.addDTO(reflect.TypeOf(pe.Message), true) if err != nil { return err } } for _, pe := range c.PossibleErrors { - err = stub.addDTO(reflect.TypeOf(pe.Message)) + err = stub.addDTO(reflect.TypeOf(pe.Message), true) if err != nil { return err } @@ -207,8 +207,10 @@ func (s Service) rpcStub( if dto, ok := stub.getDTO(reflect.TypeOf(c.Input)); ok { m.Request = dto } - if dto, ok := stub.getDTO(reflect.TypeOf(c.Output)); ok { - m.Response = dto + if c.Output != nil { + if dto, ok := stub.getDTO(reflect.TypeOf(c.Output)); ok { + m.Response = dto + } } var possibleErrors []Error @@ -216,8 +218,7 @@ func (s Service) rpcStub( possibleErrors = append(possibleErrors, c.PossibleErrors...) for _, e := range possibleErrors { if dto, ok := stub.getDTO(reflect.TypeOf(e.Message)); ok { - m.PossibleErrors = append( - m.PossibleErrors, + m.addPossibleError( ErrorDTO{ Code: e.Code, Item: e.Item, @@ -241,8 +242,10 @@ func (s Service) restStub( if dto, ok := stub.getDTO(reflect.TypeOf(c.Input)); ok { m.Request = dto } - if dto, ok := stub.getDTO(reflect.TypeOf(c.Output)); ok { - m.Response = dto + if c.Output != nil { + if dto, ok := stub.getDTO(reflect.TypeOf(c.Output)); ok { + m.Response = dto + } } var possibleErrors []Error @@ -250,8 +253,7 @@ func (s Service) restStub( possibleErrors = append(possibleErrors, c.PossibleErrors...) for _, e := range possibleErrors { if dto, ok := stub.getDTO(reflect.TypeOf(e.Message)); ok { - m.PossibleErrors = append( - m.PossibleErrors, + m.addPossibleError( ErrorDTO{ Code: e.Code, Item: e.Item, diff --git a/desc/stub.go b/desc/stub.go index 4d898258..0b0a913b 100644 --- a/desc/stub.go +++ b/desc/stub.go @@ -3,6 +3,7 @@ package desc import ( "fmt" "reflect" + "strings" "github.com/clubpay/ronykit" ) @@ -12,9 +13,46 @@ type DTO struct { Comments []string Name string Type string + IsErr bool Fields []DTOField } +func (dto DTO) CodeField() string { + var fn string + for _, f := range dto.Fields { + x := strings.ToLower(f.Name) + if f.Type != "int" { + continue + } + if x == "code" { + return f.Name + } + if strings.HasPrefix(f.Name, "code") { + fn = f.Name + } + } + + return fn +} + +func (dto DTO) ItemField() string { + var fn string + for _, f := range dto.Fields { + x := strings.ToLower(f.Name) + if f.Type != "string" { + continue + } + if x == "item" || x == "items" { + return f.Name + } + if strings.HasPrefix(f.Name, "item") { + fn = f.Name + } + } + + return fn +} + // DTOField represents description of a field of the DTO type DTOField struct { Name string @@ -49,6 +87,15 @@ type RESTMethod struct { PossibleErrors []ErrorDTO } +func (rm *RESTMethod) addPossibleError(dto ErrorDTO) { + for _, e := range rm.PossibleErrors { + if e.Code == dto.Code { + return + } + } + rm.PossibleErrors = append(rm.PossibleErrors, dto) +} + // RPCMethod represents description of a Contract with ronykit.RPCRouteSelector type RPCMethod struct { Name string @@ -61,6 +108,15 @@ type RPCMethod struct { ronykit.OutgoingRPCContainer } +func (rm *RPCMethod) addPossibleError(dto ErrorDTO) { + for _, e := range rm.PossibleErrors { + if e.Code == dto.Code { + return + } + } + rm.PossibleErrors = append(rm.PossibleErrors, dto) +} + // Stub represents description of a stub of the service described by Service descriptor. type Stub struct { tags []string @@ -78,8 +134,10 @@ func newStub(tags ...string) *Stub { } } -func (d *Stub) addDTO(mTyp reflect.Type) error { - dto := DTO{} +func (d *Stub) addDTO(mTyp reflect.Type, isErr bool) error { + dto := DTO{ + IsErr: isErr, + } if mTyp.Kind() == reflect.Ptr { mTyp = mTyp.Elem() } @@ -93,12 +151,12 @@ func (d *Stub) addDTO(mTyp reflect.Type) error { switch { case k == reflect.Struct: - err := d.addDTO(ft.Type) + err := d.addDTO(ft.Type, false) if err != nil { return err } case k == reflect.Ptr && ft.Type.Elem().Kind() == reflect.Struct: - err := d.addDTO(ft.Type.Elem()) + err := d.addDTO(ft.Type.Elem(), false) if err != nil { return err } diff --git a/desc/stub_test.go b/desc/stub_test.go index 1c127179..052ef5a1 100644 --- a/desc/stub_test.go +++ b/desc/stub_test.go @@ -12,7 +12,7 @@ var _ = Describe("Desc", func() { It("should detect all DTOs", func() { d := desc.NewStub("json") - Expect(d.AddDTO(reflect.TypeOf(&customStruct{}))).To(Succeed()) + Expect(d.AddDTO(reflect.TypeOf(&customStruct{}), false)).To(Succeed()) Expect(d.DTOs).To(HaveLen(2)) Expect(d.DTOs["customSubStruct"].Fields).To(HaveLen(2)) Expect(d.DTOs["customSubStruct"].Fields[0].Name).To(Equal("SubParam1")) diff --git a/exmples/simple-rest-server/stub/sampleservice.go b/exmples/simple-rest-server/stub/sampleservice.go index 6e395655..2a588888 100755 --- a/exmples/simple-rest-server/stub/sampleservice.go +++ b/exmples/simple-rest-server/stub/sampleservice.go @@ -44,6 +44,13 @@ type ErrorMessage struct { Item string `json:"item"` } +func (x ErrorMessage) GetCode() int { + return x.Code +} +func (x ErrorMessage) GetItem() string { + return x.Item +} + // RedirectRequest is a data transfer object type RedirectRequest struct { URL string `json:"url"` @@ -92,7 +99,7 @@ func (s SampleServiceStub) Echo(ctx context.Context, req *EchoRequest) (*EchoRes return err } - return stub.NewErrorWithMsg(400, "INPUT", res) + return stub.NewErrorWithMsg(res) }, ). DefaultResponseHandler( @@ -123,7 +130,7 @@ func (s SampleServiceStub) Sum1(ctx context.Context, req *SumRequest) (*SumRespo return err } - return stub.NewErrorWithMsg(400, "INPUT", res) + return stub.NewErrorWithMsg(res) }, ). DefaultResponseHandler( @@ -154,7 +161,7 @@ func (s SampleServiceStub) Sum2(ctx context.Context, req *SumRequest) (*SumRespo return err } - return stub.NewErrorWithMsg(400, "INPUT", res) + return stub.NewErrorWithMsg(res) }, ). DefaultResponseHandler( @@ -185,7 +192,7 @@ func (s SampleServiceStub) SumRedirect(ctx context.Context, req *SumRequest) (*S return err } - return stub.NewErrorWithMsg(400, "INPUT", res) + return stub.NewErrorWithMsg(res) }, ). DefaultResponseHandler( diff --git a/internal/tpl/go/stub.gotmpl b/internal/tpl/go/stub.gotmpl index 6924c1f5..f79e70b9 100644 --- a/internal/tpl/go/stub.gotmpl +++ b/internal/tpl/go/stub.gotmpl @@ -16,6 +16,18 @@ {{- end }} {{- end }} } + {{ if .IsErr }} + {{- if ne .CodeField ""}} + func (x {{.Name}}) GetCode() int { + return x.{{.CodeField}} + } + {{- end }} + {{- if ne .ItemField ""}} + func (x {{.Name}}) GetItem() string { + return x.{{.ItemField}} + } + {{- end }} + {{- end }} {{ end }} // Code generated by RonyKIT Stub Generator (Golang); DO NOT EDIT. @@ -75,7 +87,7 @@ func (s {{$serviceName}}Stub) {{$methodName}}(ctx context.Context, req *{{.Reque return err } - return stub.NewErrorWithMsg({{$errDto.Code}}, "{{$errDto.Item}}", res) + return stub.NewErrorWithMsg(res) }, ). {{- end }} diff --git a/stub/error.go b/stub/error.go index 5da98644..452f1018 100644 --- a/stub/error.go +++ b/stub/error.go @@ -20,12 +20,18 @@ func NewError(code int, item string) *Error { } } -func NewErrorWithMsg(code int, item string, msg ronykit.Message) *Error { - return &Error{ - code: code, - item: item, - msg: msg, +func NewErrorWithMsg(msg ronykit.Message) *Error { + wErr := &Error{ + msg: msg, } + if e, ok := msg.(interface{ GetCode() int }); ok { + wErr.code = e.GetCode() + } + if e, ok := msg.(interface{ GetItem() string }); ok { + wErr.item = e.GetItem() + } + + return wErr } func WrapError(err error) *Error {