diff --git a/impl.go b/impl.go index 0a9906b..17b16dc 100644 --- a/impl.go +++ b/impl.go @@ -125,26 +125,20 @@ type Pkg struct { recvPkg string } -// Spec is ast.TypeSpec with the associated comment map. -type Spec struct { - *ast.TypeSpec - ast.CommentMap -} - // typeSpec locates the *ast.TypeSpec for type id in the import path. -func typeSpec(path string, id string, srcDir string) (Pkg, Spec, error) { +func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) { var pkg *build.Package var err error if path == "" { pkg, err = build.ImportDir(srcDir, 0) if err != nil { - return Pkg{}, Spec{}, fmt.Errorf("couldn't find package in %s: %v", srcDir, err) + return Pkg{}, nil, fmt.Errorf("couldn't find package in %s: %v", srcDir, err) } } else { pkg, err = build.Import(path, srcDir, 0) if err != nil { - return Pkg{}, Spec{}, fmt.Errorf("couldn't find package %s: %v", path, err) + return Pkg{}, nil, fmt.Errorf("couldn't find package %s: %v", path, err) } } @@ -158,8 +152,6 @@ func typeSpec(path string, id string, srcDir string) (Pkg, Spec, error) { continue } - cmap := ast.NewCommentMap(fset, f, f.Comments) - for _, decl := range f.Decls { decl, ok := decl.(*ast.GenDecl) if !ok || decl.Tok != token.TYPE { @@ -171,12 +163,11 @@ func typeSpec(path string, id string, srcDir string) (Pkg, Spec, error) { continue } p := Pkg{Package: pkg, FileSet: fset} - s := Spec{TypeSpec: spec, CommentMap: cmap.Filter(decl)} - return p, s, nil + return p, spec, nil } } } - return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", id, path) + return Pkg{}, nil, fmt.Errorf("type %s not found in %s", id, path) } // gofmt pretty-prints e. @@ -253,7 +244,7 @@ const ( WithoutComments EmitComments = false ) -func (p Pkg) funcsig(f *ast.Field, cmap ast.CommentMap, comments EmitComments) Func { +func (p Pkg) funcsig(f *ast.Field, comments EmitComments) Func { fn := Func{Name: f.Names[0].Name} typ := f.Type.(*ast.FuncType) if typ.Params != nil { @@ -273,8 +264,8 @@ func (p Pkg) funcsig(f *ast.Field, cmap ast.CommentMap, comments EmitComments) F fn.Res = append(fn.Res, p.params(field)...) } } - if commentsBefore(f, cmap.Comments()) && comments == WithComments { - fn.Comments = flattenCommentMap(cmap) + if comments == WithComments && f.Doc != nil { + fn.Comments = flattenDocComment(f) } return fn } @@ -328,7 +319,7 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error) continue } - fn := p.funcsig(fndecl, spec.CommentMap.Filter(fndecl), comments) + fn := p.funcsig(fndecl, comments) fns = append(fns, fn) } return fns, nil @@ -394,31 +385,14 @@ func validReceiver(recv string) bool { return err == nil } -// commentsBefore reports whether commentGroups precedes a field. -func commentsBefore(field *ast.Field, cg []*ast.CommentGroup) bool { - if len(cg) > 0 { - return cg[0].Pos() < field.Pos() - } - return false -} - -// flattenCommentMap flattens the comment map to a string. -// This function must be used at the point when m is expected to have a single -// element. -func flattenCommentMap(m ast.CommentMap) string { - if len(m) != 1 { - panic("flattenCommentMap expects comment map of length 1") - } +// flattenDocComment flattens the field doc comments to a string +func flattenDocComment(f *ast.Field) string { var result strings.Builder - for _, cgs := range m { - for _, cg := range cgs { - for _, c := range cg.List { - result.WriteString(c.Text) - // add an end-of-line character if this is '//'-style comment - if c.Text[1] == '/' { - result.WriteString("\n") - } - } + for _, c := range f.Doc.List { + result.WriteString(c.Text) + // add an end-of-line character if this is '//'-style comment + if c.Text[1] == '/' { + result.WriteString("\n") } } diff --git a/impl_test.go b/impl_test.go index 2c8e6a9..ffed4e4 100644 --- a/impl_test.go +++ b/impl_test.go @@ -77,8 +77,8 @@ func TestTypeSpec(t *testing.T) { if reflect.DeepEqual(p, Pkg{}) { t.Errorf("typeSpec(%q, %q).pkg=Pkg{} want non-nil", tt.path, tt.id) } - if reflect.DeepEqual(spec, Spec{}) { - t.Errorf("typeSpec(%q, %q).spec=Spec{} want non-nil", tt.path, tt.id) + if spec == nil { + t.Errorf("typeSpec(%q, %q).spec=nil want non-nil", tt.path, tt.id) } } }