Skip to content

Commit

Permalink
fix: Support for multiple structs in one file (#2)
Browse files Browse the repository at this point in the history
ginokent authored Nov 10, 2023

Verified

This commit was signed with the committer’s verified signature.
snyk-bot Snyk bot
2 parents 7ef5d77 + 3288775 commit ba2cb2c
Showing 15 changed files with 388 additions and 111 deletions.
44 changes: 23 additions & 21 deletions internal/arcgen/lang/go/dump_source.go
Original file line number Diff line number Diff line change
@@ -9,27 +9,29 @@ import (
"github.com/kunitsucom/arcgen/internal/logs"
)

func dumpSource(fset *token.FileSet, arcSrcSet ARCSourceSet) {
for _, arcSrc := range arcSrcSet {
logs.Trace.Print("== Source ================================================================================================================================")
_, _ = io.WriteString(logs.Trace.LineWriter("r.CommentGroup.Text: "), arcSrc.CommentGroup.Text())
logs.Trace.Print("-- CommentGroup --------------------------------------------------------------------------------------------------------------------------------")
{
commentGroupAST := bytes.NewBuffer(nil)
goast.Fprint(commentGroupAST, fset, arcSrc.CommentGroup, goast.NotNilFilter)
_, _ = logs.Trace.LineWriter("").Write(commentGroupAST.Bytes())
}
logs.Trace.Print("-- TypeSpec --------------------------------------------------------------------------------------------------------------------------------")
{
typeSpecAST := bytes.NewBuffer(nil)
goast.Fprint(typeSpecAST, fset, arcSrc.TypeSpec, goast.NotNilFilter)
_, _ = logs.Trace.LineWriter("").Write(typeSpecAST.Bytes())
}
logs.Trace.Print("-- StructType --------------------------------------------------------------------------------------------------------------------------------")
{
structTypeAST := bytes.NewBuffer(nil)
goast.Fprint(structTypeAST, fset, arcSrc.StructType, goast.NotNilFilter)
_, _ = logs.Trace.LineWriter("").Write(structTypeAST.Bytes())
func dumpSource(fset *token.FileSet, arcSrcSet *ARCSourceSet) {
if arcSrcSet != nil {
for _, arcSrc := range arcSrcSet.ARCSources {
logs.Trace.Print("== Source ================================================================================================================================")
_, _ = io.WriteString(logs.Trace.LineWriter("r.CommentGroup.Text: "), arcSrc.CommentGroup.Text())
logs.Trace.Print("-- CommentGroup --------------------------------------------------------------------------------------------------------------------------------")
{
commentGroupAST := bytes.NewBuffer(nil)
goast.Fprint(commentGroupAST, fset, arcSrc.CommentGroup, goast.NotNilFilter)
_, _ = logs.Trace.LineWriter("").Write(commentGroupAST.Bytes())
}
logs.Trace.Print("-- TypeSpec --------------------------------------------------------------------------------------------------------------------------------")
{
typeSpecAST := bytes.NewBuffer(nil)
goast.Fprint(typeSpecAST, fset, arcSrc.TypeSpec, goast.NotNilFilter)
_, _ = logs.Trace.LineWriter("").Write(typeSpecAST.Bytes())
}
logs.Trace.Print("-- StructType --------------------------------------------------------------------------------------------------------------------------------")
{
structTypeAST := bytes.NewBuffer(nil)
goast.Fprint(structTypeAST, fset, arcSrc.StructType, goast.NotNilFilter)
_, _ = logs.Trace.LineWriter("").Write(structTypeAST.Bytes())
}
}
}
}
69 changes: 56 additions & 13 deletions internal/arcgen/lang/go/extract_source.go
Original file line number Diff line number Diff line change
@@ -15,31 +15,63 @@ import (
apperr "github.com/kunitsucom/arcgen/pkg/errors"
)

//nolint:gocognit,cyclop
func extractSource(_ context.Context, fset *token.FileSet, f *goast.File) (ARCSourceSet, error) {
arcSrcSet := make(ARCSourceSet, 0)
//nolint:cyclop,funlen,gocognit
func extractSource(_ context.Context, fset *token.FileSet, f *goast.File) (*ARCSourceSet, error) {
// NOTE: Use map to avoid duplicate entries.
arcSrcMap := make(map[string]*ARCSource)

goast.Inspect(f, func(node goast.Node) bool {
switch n := node.(type) {
case *goast.TypeSpec:
typeSpec := n
switch t := n.Type.(type) {
case *goast.StructType:
structType := t
if hasColumnTagGo(t) {
pos := fset.Position(structType.Pos())
logs.Debug.Printf("🔍: %s: type=%s", pos.String(), n.Name.Name)
arcSrcMap[pos.String()] = &ARCSource{
Source: pos,
Package: f.Name,
TypeSpec: typeSpec,
StructType: structType,
}
}
return false
default: // noop
}
default: // noop
}
return true
})

// Since it is not possible to extract the comment group associated with the position of struct,
// search for the struct associated with the comment group and overwrite it.
for commentedNode, commentGroups := range goast.NewCommentMap(fset, f, f.Comments) {
for _, commentGroup := range commentGroups {
CommentGroupLoop:
for _, commentLine := range commentGroup.List {
commentGroup := commentGroup // MEMO: Using the variable on range scope `commentGroup` in function literal (scopelint)
logs.Trace.Printf("commentLine=%s: %s", filepathz.Short(fset.Position(commentGroup.Pos()).String()), commentLine.Text)
// NOTE: If the comment line matches the ColumnTagGo, it is assumed to be a comment line for the struct.
if matches := ColumnTagGoCommentLineRegex().FindStringSubmatch(commentLine.Text); len(matches) > _ColumnTagGoCommentLineRegexContentIndex {
s := &ARCSource{
Position: fset.Position(commentLine.Pos()),
Package: f.Name,
CommentGroup: commentGroup,
}
goast.Inspect(commentedNode, func(node goast.Node) bool {
switch n := node.(type) {
case *goast.TypeSpec:
s.TypeSpec = n
typeSpec := n
switch t := n.Type.(type) {
case *goast.StructType:
s.StructType = t
structType := t
if hasColumnTagGo(t) {
logs.Debug.Printf("🔍: %s: type=%s", fset.Position(t.Pos()).String(), n.Name.Name)
arcSrcSet = append(arcSrcSet, s)
pos := fset.Position(structType.Pos())
logs.Debug.Printf("🖋️: %s: overwrite with comment group: type=%s", fset.Position(t.Pos()).String(), n.Name.Name)
arcSrcMap[pos.String()] = &ARCSource{
Source: pos,
Package: f.Name,
TypeSpec: typeSpec,
StructType: structType,
CommentGroup: commentGroup,
}
}
return false
default: // noop
@@ -54,7 +86,18 @@ func extractSource(_ context.Context, fset *token.FileSet, f *goast.File) (ARCSo
}
}

if len(arcSrcSet) == 0 {
arcSrcSet := &ARCSourceSet{
Filename: fset.Position(f.Pos()).Filename,
PackageName: f.Name.Name,
Source: fset.Position(f.Pos()),
ARCSources: make([]*ARCSource, 0),
}

for _, arcSrc := range arcSrcMap {
arcSrcSet.ARCSources = append(arcSrcSet.ARCSources, arcSrc)
}

if len(arcSrcSet.ARCSources) == 0 {
return nil, errorz.Errorf("column-tag-go=%s: %w", config.ColumnTagGo(), apperr.ErrColumnTagGoAnnotationNotFoundInSource)
}

151 changes: 80 additions & 71 deletions internal/arcgen/lang/go/generate.go
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ import (
"go/ast"
"go/format"
"go/token"
"io"
"os"
"reflect"
"strconv"
@@ -17,9 +18,10 @@ import (

"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
"github.com/kunitsucom/arcgen/internal/config"
"github.com/kunitsucom/arcgen/internal/logs"
)

//nolint:cyclop
//nolint:cyclop,funlen
func Generate(ctx context.Context, src string) error {
arcSrcSets, err := parse(ctx, src)
if err != nil {
@@ -29,15 +31,23 @@ func Generate(ctx context.Context, src string) error {
newFile := token.NewFileSet()

for _, arcSrcSet := range arcSrcSets {
for _, arcSrc := range arcSrcSet {
filePrefix := strings.TrimSuffix(arcSrc.Position.Filename, fileSuffix)
filename := fmt.Sprintf("%s.%s.gen%s", filePrefix, config.ColumnTagGo(), fileSuffix)
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
return errorz.Errorf("os.Open: %w", err)
}
filePrefix := strings.TrimSuffix(arcSrcSet.Filename, fileSuffix)
filename := fmt.Sprintf("%s.%s.gen%s", filePrefix, config.ColumnTagGo(), fileSuffix)
osFile, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
return errorz.Errorf("os.OpenFile: %w", err)
}

packageName := arcSrc.Package.Name
astFile := &ast.File{
// package
Name: &ast.Ident{
Name: arcSrcSet.PackageName,
},
// methods
Decls: []ast.Decl{},
}

for _, arcSrc := range arcSrcSet.ARCSources {
structName := arcSrc.TypeSpec.Name.Name
tableName := extractTableNameFromCommentGroup(arcSrc.CommentGroup)
columnNames := func() []string {
@@ -47,6 +57,7 @@ func Generate(ctx context.Context, src string) error {
tag := reflect.StructTag(strings.Trim(field.Tag.Value, "`"))
switch columnName := tag.Get(config.ColumnTagGo()); columnName {
case "", "-":
logs.Trace.Printf("SKIP: %s: field.Names=%s, columnName=%q", arcSrc.Source.String(), field.Names, columnName)
// noop
default:
columnNames = append(columnNames, columnName)
@@ -56,101 +67,99 @@ func Generate(ctx context.Context, src string) error {
return columnNames
}()

node := generateASTFile(packageName, structName, tableName, config.MethodPrefixGlobal(), config.MethodPrefixColumn(), columnNames)
appendAST(astFile, structName, tableName, config.MethodPrefixGlobal(), config.MethodPrefixColumn(), columnNames)
}

buf := bytes.NewBuffer(nil)
if err := format.Node(buf, newFile, node); err != nil {
return errorz.Errorf("format.Node: %w", err)
}
buf := bytes.NewBuffer(nil)
if err := format.Node(buf, newFile, astFile); err != nil {
return errorz.Errorf("format.Node: %w", err)
}

// add header comment
s := strings.Replace(
buf.String(),
"package "+packageName+"\n",
fmt.Sprintf("// Code generated by arcgen. DO NOT EDIT.\n//\n// source: %s:%d\n\npackage "+packageName+"\n", filepathz.Short(arcSrc.Position.Filename), arcSrc.Position.Line),
1,
)
// add blank line between methods
s = strings.ReplaceAll(s, "\n}\nfunc ", "\n}\n\nfunc ")
// add header comment
content := "" +
"// Code generated by arcgen. DO NOT EDIT." + "\n" +
"//" + "\n" +
fmt.Sprintf("// source: %s", filepathz.Short(arcSrcSet.Source.Filename)) + "\n" +
"\n" +
buf.String()

// write to file
if _, err := f.WriteString(s); err != nil {
return errorz.Errorf("f.WriteString: %w", err)
}
// add blank line between methods
content = strings.ReplaceAll(content, "\n}\nfunc ", "\n}\n\nfunc ")

// write to file
if _, err := io.WriteString(osFile, content); err != nil {
return errorz.Errorf("io.WriteString: %w", err)
}
}

return nil
}

func extractTableNameFromCommentGroup(commentGroup *ast.CommentGroup) string {
for _, comment := range commentGroup.List {
if matches := util.RegexIndexTableName.Regex.FindStringSubmatch(comment.Text); len(matches) > util.RegexIndexTableName.Index {
return matches[util.RegexIndexTableName.Index]
if commentGroup != nil {
for _, comment := range commentGroup.List {
if matches := util.RegexIndexTableName.Regex.FindStringSubmatch(comment.Text); len(matches) > util.RegexIndexTableName.Index {
return matches[util.RegexIndexTableName.Index]
}
}
}
return fmt.Sprintf("ERROR: TABLE NAME IN COMMENT `// \"%s\": table: *` NOT FOUND: comment=%q", config.ColumnTagGo(), commentGroup.Text())

logs.Debug.Printf("WARN: table name in comment not found: `// \"%s\": table: *`: comment=%q", config.ColumnTagGo(), commentGroup.Text())
return ""
}

//nolint:funlen
func generateASTFile(packageName string, structName string, tableName string, prefixGlobal string, prefixColumn string, columnNames []string) *ast.File {
file := &ast.File{
// package
Name: &ast.Ident{
Name: packageName,
},
// methods
Decls: []ast.Decl{
&ast.FuncDecl{
Recv: &ast.FieldList{
List: []*ast.Field{
{
Names: []*ast.Ident{
{
Name: "s",
},
func appendAST(file *ast.File, structName string, tableName string, prefixGlobal string, prefixColumn string, columnNames []string) {
if tableName != "" {
file.Decls = append(file.Decls, &ast.FuncDecl{
Recv: &ast.FieldList{
List: []*ast.Field{
{
Names: []*ast.Ident{
{
Name: "s",
},
Type: &ast.StarExpr{
X: &ast.Ident{
Name: structName, // MEMO: struct name
},
},
Type: &ast.StarExpr{
X: &ast.Ident{
Name: structName, // MEMO: struct name
},
},
},
},
Name: &ast.Ident{
Name: prefixGlobal + "TableName",
},
Type: &ast.FuncType{
Params: &ast.FieldList{},
Results: &ast.FieldList{
List: []*ast.Field{
{
Type: &ast.Ident{
Name: "string",
},
},
Name: &ast.Ident{
Name: prefixGlobal + "TableName",
},
Type: &ast.FuncType{
Params: &ast.FieldList{},
Results: &ast.FieldList{
List: []*ast.Field{
{
Type: &ast.Ident{
Name: "string",
},
},
},
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ReturnStmt{
Results: []ast.Expr{
&ast.Ident{
Name: strconv.Quote(tableName),
},
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ReturnStmt{
Results: []ast.Expr{
&ast.Ident{
Name: strconv.Quote(tableName),
},
},
},
},
},
},
})
}

file.Decls = append(file.Decls, generateASTColumnMethods(structName, prefixGlobal, prefixColumn, columnNames)...)

return file
return //nolint:gosimple
}

//nolint:funlen
125 changes: 125 additions & 0 deletions internal/arcgen/lang/go/generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
//nolint:testpackage
package arcgengo

import (
"context"
"io"
"os"
"testing"

"github.com/kunitsucom/util.go/testing/assert"
"github.com/kunitsucom/util.go/testing/require"

"github.com/kunitsucom/arcgen/internal/config"
"github.com/kunitsucom/arcgen/internal/contexts"
)

//nolint:paralleltest
func TestGenerate(t *testing.T) {
t.Run("success,tests", func(t *testing.T) {
ctx := contexts.WithArgs(context.Background(), []string{
"ddlgen",
"--column-tag-go=dbtest",
"--method-prefix-global=Get",
// "--src=tests/common.source",
"--src=tests",
})

backup := fileSuffix
t.Cleanup(func() { fileSuffix = backup })

_, err := config.Load(ctx)
require.NoError(t, err)

fileSuffix = ".source"
require.NoError(t, Generate(ctx, config.Source()))

{
expectedFile, err := os.Open("tests/common.golden")
require.NoError(t, err)
expectedBytes, err := io.ReadAll(expectedFile)
require.NoError(t, err)
expected := string(expectedBytes)

actualFile, err := os.Open("tests/common.dbtest.gen.source")
require.NoError(t, err)
actualBytes, err := io.ReadAll(actualFile)
require.NoError(t, err)
actual := string(actualBytes)

assert.Equal(t, expected, actual)
}
})

t.Run("failure,no.errsource", func(t *testing.T) {
ctx := contexts.WithArgs(context.Background(), []string{
"ddlgen",
"--column-tag-go=dbtest",
"--method-prefix-global=Get",
"--src=tests/no.errsource",
})

backup := fileSuffix
t.Cleanup(func() { fileSuffix = backup })

_, err := config.Load(ctx)
require.NoError(t, err)

fileSuffix = ".source"
require.ErrorsContains(t, Generate(ctx, config.Source()), "expected 'package', found 'EOF'")
})

t.Run("failure,no.errsource", func(t *testing.T) {
ctx := contexts.WithArgs(context.Background(), []string{
"ddlgen",
"--column-tag-go=dbtest",
"--method-prefix-global=Get",
"--src=tests",
})

backup := fileSuffix
t.Cleanup(func() { fileSuffix = backup })

_, err := config.Load(ctx)
require.NoError(t, err)

fileSuffix = ".errsource"
require.ErrorsContains(t, Generate(ctx, config.Source()), "expected 'package', found 'EOF'")
})

t.Run("failure,no-such-file-or-directory", func(t *testing.T) {
ctx := contexts.WithArgs(context.Background(), []string{
"ddlgen",
"--column-tag-go=dbtest",
"--method-prefix-global=Get",
"--src=tests/no-such-file-or-directory",
})

backup := fileSuffix
t.Cleanup(func() { fileSuffix = backup })

_, err := config.Load(ctx)
require.NoError(t, err)

fileSuffix = ".source"
require.ErrorsContains(t, Generate(ctx, config.Source()), "no such file or directory")
})

t.Run("failure,directory.dir", func(t *testing.T) {
ctx := contexts.WithArgs(context.Background(), []string{
"ddlgen",
"--column-tag-go=dbtest",
"--method-prefix-global=Get",
"--src=tests/directory.dir",
})

backup := fileSuffix
t.Cleanup(func() { fileSuffix = backup })

_, err := config.Load(ctx)
require.NoError(t, err)

fileSuffix = ".dir"
require.ErrorsContains(t, Generate(ctx, config.Source()), "is a directory")
})
}
2 changes: 1 addition & 1 deletion internal/arcgen/lang/go/parse.go
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ func walkDirFn(ctx context.Context, arcSrcSets *ARCSourceSets) func(path string,
}
}

func parseFile(ctx context.Context, filename string) (ARCSourceSet, error) {
func parseFile(ctx context.Context, filename string) (*ARCSourceSet, error) {
fset := token.NewFileSet()
rootNode, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
if err != nil {
13 changes: 9 additions & 4 deletions internal/arcgen/lang/go/source.go
Original file line number Diff line number Diff line change
@@ -12,16 +12,21 @@ import (

type (
ARCSource struct {
Position token.Position
Package *ast.Ident
Source token.Position // TODO: Unnecessary?
Package *ast.Ident
// TypeSpec is used to guess the table name if the CREATE TABLE annotation is not found.
TypeSpec *ast.TypeSpec
// StructType is used to determine the column name. If the tag specified by --column-tag-go is not found, the field name is used.
StructType *ast.StructType
CommentGroup *ast.CommentGroup
}
ARCSourceSet []*ARCSource
ARCSourceSets []ARCSourceSet
ARCSourceSet struct {
Source token.Position
Filename string
PackageName string
ARCSources []*ARCSource
}
ARCSourceSets []*ARCSourceSet
)

//nolint:gochecknoglobals
1 change: 1 addition & 0 deletions internal/arcgen/lang/go/tests/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.gen.source
3 changes: 3 additions & 0 deletions internal/arcgen/lang/go/tests/column-tag-go.source
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package tests

// dbtest: table: `Tables`
41 changes: 41 additions & 0 deletions internal/arcgen/lang/go/tests/common.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Code generated by arcgen. DO NOT EDIT.
//
// source: tests/common.source

package main

func (s *User) GetTableName() string {
return "`Users`"
}

func (s *User) GetColumnNames() []string {
return []string{"Id", "Name", "Email", "Age"}
}

func (s *User) GetColumnName_Id() string {
return "Id"
}

func (s *User) GetColumnName_Name() string {
return "Name"
}

func (s *User) GetColumnName_Email() string {
return "Email"
}

func (s *User) GetColumnName_Age() string {
return "Age"
}

func (s *Group) GetColumnNames() []string {
return []string{"Id", "Name"}
}

func (s *Group) GetColumnName_Id() string {
return "Id"
}

func (s *Group) GetColumnName_Name() string {
return "Name"
}
36 changes: 36 additions & 0 deletions internal/arcgen/lang/go/tests/common.source
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package main

type (
// User is a user.
//
// dbtest: table: `Users`
User struct {
// ID is a user ID.
ID string `dbtest:"Id"`
// Name is a user name.
Name string `dbtest:"Name"`
// Email is a user email.
Email string `dbtest:"Email"`
// Age is a user age.
Age int `dbtest:"Age"`
// Ignore is a ignore field.
Ignore string `dbtest:"-"`
}

// Users is a slice of User.
//
// dbtest: table: `Users`
Users []*User

// dbtest: table: `InvalidUsers`
InvalidUser struct {
ID string
}

// Group is a group.
//
Group struct {
ID string `dbtest:"Id"`
Name string `dbtest:"Name"`
}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*
10 changes: 10 additions & 0 deletions internal/arcgen/lang/go/tests/directory.dir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package main

type (
// ReadOnly is a struct.
//
// dbtest: table: ReadOnly
ReadOnly struct {
Name string `dbtest:"ReadOnly"`
}
)
1 change: 1 addition & 0 deletions internal/arcgen/lang/go/tests/no-column-tag-go.source
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package tests
Empty file.
2 changes: 1 addition & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
@@ -158,7 +158,7 @@ func load(ctx context.Context) (cfg *config, err error) { //nolint:unparam
Name: _OptionMethodPrefixGlobal,
Environment: _EnvKeyMethodPrefixGlobal,
Description: "global method prefix",
Default: cliz.Default("Get"),
Default: cliz.Default(""),
},
&cliz.StringOption{
Name: _OptionMethodPrefixColumn,

0 comments on commit ba2cb2c

Please sign in to comment.