From 35b52762041c227bcedbc3890b5d32f06c268eff Mon Sep 17 00:00:00 2001 From: ginokent <29125616+ginokent@users.noreply.github.com> Date: Mon, 12 Aug 2024 09:48:37 +0900 Subject: [PATCH] fix: Support postgres, mysql diarect --- internal/arcgen/lang/go/generate.go | 37 ++-- .../arcgen/lang/go/generate_crud_common.go | 174 +++++++++++++++--- .../arcgen/lang/go/generate_crud_create.go | 37 ++-- .../arcgen/lang/go/generate_crud_delete.go | 40 ++-- internal/arcgen/lang/go/generate_crud_read.go | 112 ++++++----- .../arcgen/lang/go/generate_crud_update.go | 42 +++-- internal/arcgen/lang/go/source.go | 79 ++++++-- internal/arcgen/lang/util/camel_case.go | 12 ++ internal/config/config.go | 20 ++ internal/config/dialect.go | 18 ++ internal/config/go_crud_type_name.go | 23 +++ 11 files changed, 439 insertions(+), 155 deletions(-) create mode 100644 internal/arcgen/lang/util/camel_case.go create mode 100644 internal/config/dialect.go create mode 100644 internal/config/go_crud_type_name.go diff --git a/internal/arcgen/lang/go/generate.go b/internal/arcgen/lang/go/generate.go index e5ed1a3..db0d6d2 100644 --- a/internal/arcgen/lang/go/generate.go +++ b/internal/arcgen/lang/go/generate.go @@ -57,23 +57,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error { if config.GenerateGoCRUDPackage() { crudFileExt := ".crud" + genFileExt - if err := func() error { - filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt) - f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__) - if err != nil { - return errorz.Errorf("os.OpenFile: %w", err) - } - defer f.Close() - - if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice); err != nil { - return errorz.Errorf("sprint: %w", err) - } - - return nil - }(); err != nil { - return errorz.Errorf("f: %w", err) - } - + crudFiles := make([]string, 0) for _, arcSrcSet := range arcSrcSetSlice { // closure for defer if err := func() error { @@ -84,7 +68,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error { return errorz.Errorf("os.OpenFile: %w", err) } defer f.Close() - f.Name() + crudFiles = append(crudFiles, filename) if err := fprintCRUD( f, @@ -98,6 +82,23 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error { return errorz.Errorf("f: %w", err) } } + + if err := func() error { + filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt) + f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__) + if err != nil { + return errorz.Errorf("os.OpenFile: %w", err) + } + defer f.Close() + + if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice, crudFiles); err != nil { + return errorz.Errorf("sprint: %w", err) + } + + return nil + }(); err != nil { + return errorz.Errorf("f: %w", err) + } } return nil diff --git a/internal/arcgen/lang/go/generate_crud_common.go b/internal/arcgen/lang/go/generate_crud_common.go index af24d04..c532113 100644 --- a/internal/arcgen/lang/go/generate_crud_common.go +++ b/internal/arcgen/lang/go/generate_crud_common.go @@ -2,19 +2,22 @@ package arcgengo import ( "go/ast" + "go/parser" "go/printer" "go/token" "io" + "path/filepath" "strconv" "strings" errorz "github.com/kunitsucom/util.go/errors" + "github.com/kunitsucom/arcgen/internal/arcgen/lang/util" "github.com/kunitsucom/arcgen/internal/config" ) -func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice) error { - content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice) +func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) error { + content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice, crudFiles) if err != nil { return errorz.Errorf("generateCRUDCommonFileContent: %w", err) } @@ -27,8 +30,13 @@ func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlic return nil } -//nolint:funlen -func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, error) { +const ( + sqlQueryerContextVarName = "sqlContext" + sqlQueryerContextTypeName = "sqlQueryerContext" +) + +//nolint:cyclop,funlen,gocognit,maintidx +func generateCRUDCommonFileContent(buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) (string, error) { astFile := &ast.File{ // package Name: &ast.Ident{ @@ -38,18 +46,19 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err Decls: []ast.Decl{}, } - // // Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename, - // // get the package path from arcSrcSetSlice[0].Filename. - // dir := filepath.Dir(arcSrcSetSlice[0].Filename) - // structPackagePath, err := util.GetPackagePath(dir) - // if err != nil { - // return "", errorz.Errorf("GetPackagePath: %w", err) - // } + // Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename, + // get the package path from arcSrcSetSlice[0].Filename. + dir := filepath.Dir(arcSrcSetSlice[0].Filename) + structPackagePath, err := util.GetPackagePath(dir) + if err != nil { + return "", errorz.Errorf("GetPackagePath: %w", err) + } astFile.Decls = append(astFile.Decls, // import ( // "context" // "database/sql" + // "log/slog" // // dao "path/to/your/dao" // ) @@ -62,15 +71,18 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err &ast.ImportSpec{ Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("database/sql")}, }, - // &ast.ImportSpec{ - // Name: &ast.Ident{Name: "dao"}, - // Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)}, - // }, + &ast.ImportSpec{ + Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("log/slog")}, + }, + &ast.ImportSpec{ + Name: &ast.Ident{Name: "dao"}, + Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)}, + }, }, }, ) - // type sqlContext interface { + // type sqlQueryerContext interface { // QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) // QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row // ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) @@ -80,7 +92,8 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err Tok: token.TYPE, Specs: []ast.Spec{ &ast.TypeSpec{ - Name: &ast.Ident{Name: "sqlContext"}, + // Assign: token.Pos(1), + Name: &ast.Ident{Name: sqlQueryerContextTypeName}, Type: &ast.InterfaceType{ Methods: &ast.FieldList{ List: []*ast.Field{ @@ -133,27 +146,138 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err }, ) - // type Queryer struct {} + // type _CRUD struct { + // } astFile.Decls = append(astFile.Decls, &ast.GenDecl{ Tok: token.TYPE, Specs: []ast.Spec{ &ast.TypeSpec{ - Name: &ast.Ident{Name: "Queryer"}, + Name: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}, Type: &ast.StructType{Fields: &ast.FieldList{}}, }, }, }, ) - // func NewQueryer() *Query { - // return &Queryer{} - // } + // func LoggerFromContext(ctx context.Context) *slog.Logger { + // if ctx == nil { + // return slog.Default() + // } + // if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok { + // return logger + // } + // return slog.Default() + // } + astFile.Decls = append(astFile.Decls, + &ast.FuncDecl{ + Name: &ast.Ident{Name: "LoggerFromContext"}, + Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.IfStmt{ + Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "ctx"}, Op: token.EQL, Y: &ast.Ident{Name: "nil"}}, + Body: &ast.BlockStmt{List: []ast.Stmt{ + &ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}}, + }}, + }, + &ast.IfStmt{ + // if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok { + Init: &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.Ident{Name: "logger"}, &ast.Ident{Name: "ok"}}, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + &ast.TypeAssertExpr{ + X: &ast.CallExpr{ + Fun: &ast.Ident{Name: "ctx.Value"}, + Args: []ast.Expr{&ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}}, + }, + Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}, + }, + }, + }, + Cond: &ast.Ident{Name: "ok"}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.Ident{Name: "logger"}}}}}, + }, + &ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}}, + }, + }, + }, + ) + + // func LoggerWithContext(ctx context.Context, logger *slog.Logger) context.Context { + // return context.WithValue(ctx, (*slog.Logger)(nil), logger) + // } + astFile.Decls = append(astFile.Decls, + &ast.FuncDecl{ + Name: &ast.Ident{Name: "LoggerWithContext"}, + Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, {Names: []*ast.Ident{{Name: "logger"}}, Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: "context.Context"}}}}}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "context"}, Sel: &ast.Ident{Name: "WithValue"}}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}, &ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}, &ast.Ident{Name: "logger"}}}}}, + }, + }, + }, + ) + + // type CRUD interface { + // Create{StructName}(ctx context.Context, sqlQueryer sqlQueryerContext, s *{Struct}) error + // ... + // } + methods := make([]*ast.Field, 0) + fset := token.NewFileSet() + for _, crudFile := range crudFiles { + rootNode, err := parser.ParseFile(fset, crudFile, nil, parser.ParseComments) + if err != nil { + // MEMO: parser.ParseFile err contains file path, so no need to log it + return "", errorz.Errorf("parser.ParseFile: %w", err) + } + + // MEMO: Inspect is used to get the method declaration from the file + ast.Inspect(rootNode, func(node ast.Node) bool { + switch n := node.(type) { + case *ast.FuncDecl: + //nolint:nestif + if n.Recv != nil && len(n.Recv.List) > 0 { + if t, ok := n.Recv.List[0].Type.(*ast.StarExpr); ok { + if ident, ok := t.X.(*ast.Ident); ok { + if ident.Name == config.GoCRUDTypeNameUnexported() { + methods = append(methods, &ast.Field{ + Names: []*ast.Ident{{Name: n.Name.Name}}, + Type: n.Type, + }) + } + } + } + } + default: + // noop + } + return true + }) + } + astFile.Decls = append(astFile.Decls, + &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{ + &ast.TypeSpec{ + Name: &ast.Ident{Name: config.GoCRUDTypeName()}, + Type: &ast.InterfaceType{ + Methods: &ast.FieldList{List: methods}, + }, + }, + }, + }, + ) + + // func NewCRUD() CRUD { + // return &_CRUD{} + // } astFile.Decls = append(astFile.Decls, &ast.FuncDecl{ - Name: &ast.Ident{Name: "NewQueryer"}, - Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}}}, - Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: "Queryer{}"}}}}}}, + Name: &ast.Ident{Name: "New" + config.GoCRUDTypeName()}, + Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: config.GoCRUDTypeName()}}}}}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported() + "{}"}}}}}}, }, ) diff --git a/internal/arcgen/lang/go/generate_crud_create.go b/internal/arcgen/lang/go/generate_crud_create.go index 07aa70a..9ecf4c5 100644 --- a/internal/arcgen/lang/go/generate_crud_create.go +++ b/internal/arcgen/lang/go/generate_crud_create.go @@ -5,6 +5,8 @@ import ( "go/token" "strconv" "strings" + + "github.com/kunitsucom/arcgen/internal/config" ) //nolint:funlen @@ -13,13 +15,14 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { structName := arcSrc.extractStructName() tableName := arcSrc.extractTableNameFromCommentGroup() tableInfo := arcSrc.extractFieldNamesAndColumnNames() - columnNames := tableInfo.ColumnNames() + columnNames := tableInfo.Columns.ColumnNames() - // const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES (?, ?)` + // const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES ($1, $2)` // - // func (q *query) Create{StructName}(ctx context.Context, queryer sqlContext, s *{Struct}) error { - // if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil { - // return fmt.Errorf("q.queryer.ExecContext: %w", err) + // func (q *query) Create{StructName}(ctx context.Context, queryer sqlQueryerContext, s *{Struct}) error { + // LoggerFromContext(ctx).Debug(Create{StructName}Query) + // if _, err := sqlContext.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil { + // return fmt.Errorf("sqlContext.ExecContext: %w", err) // } // return nil // } @@ -33,18 +36,18 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: queryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (?" + strings.Repeat(", ?", len(columnNames)-1) + ")`", + Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (" + columnValuesPlaceholder(columnNames) + ")`", }}, }, }, }, &ast.FuncDecl{ - Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}}, + Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}}}}, Name: &ast.Ident{Name: funcName}, Type: &ast.FuncType{ Params: &ast.FieldList{List: []*ast.Field{ {Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, - {Names: []*ast.Ident{{Name: "sqlCtx"}}, Type: &ast.Ident{Name: "sqlContext"}}, + {Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, Type: &ast.Ident{Name: sqlQueryerContextTypeName}}, {Names: []*ast.Ident{{Name: "s"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "dao." + structName}}}, }}, Results: &ast.FieldList{List: []*ast.Field{ @@ -53,14 +56,24 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { }, Body: &ast.BlockStmt{ List: []ast.Stmt{ + &ast.ExprStmt{ + // LoggerFromContext(ctx).Debug(queryName) + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}}, + Sel: &ast.Ident{Name: "Debug"}, + }, + Args: []ast.Expr{&ast.Ident{Name: queryName}}, + }, + }, &ast.IfStmt{ - // if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil { + // if _, err := sqlQueryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil { Init: &ast.AssignStmt{ Lhs: []ast.Expr{&ast.Ident{Name: "_"}, &ast.Ident{Name: "err"}}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{ - X: &ast.Ident{Name: "sqlCtx"}, + X: &ast.Ident{Name: sqlQueryerContextVarName}, Sel: &ast.Ident{Name: "ExecContext"}, }, Args: append( @@ -80,10 +93,10 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // err != nil { Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "err"}, Op: token.NEQ, Y: &ast.Ident{Name: "nil"}}, Body: &ast.BlockStmt{List: []ast.Stmt{ - // return fmt.Errorf("queryer.ExecContext: %w", err) + // return fmt.Errorf("sqlContext.ExecContext: %w", err) &ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "fmt"}, Sel: &ast.Ident{Name: "Errorf"}}, - Args: []ast.Expr{&ast.Ident{Name: strconv.Quote("queryer.ExecContext: %w")}, &ast.Ident{Name: "err"}}, + Args: []ast.Expr{&ast.Ident{Name: strconv.Quote(sqlQueryerContextVarName + ".ExecContext: %w")}, &ast.Ident{Name: "err"}}, }}}, }}, }, diff --git a/internal/arcgen/lang/go/generate_crud_delete.go b/internal/arcgen/lang/go/generate_crud_delete.go index c1803ee..125f890 100644 --- a/internal/arcgen/lang/go/generate_crud_delete.go +++ b/internal/arcgen/lang/go/generate_crud_delete.go @@ -4,7 +4,9 @@ import ( "go/ast" "go/token" "strconv" - "strings" + + "github.com/kunitsucom/arcgen/internal/arcgen/lang/util" + "github.com/kunitsucom/arcgen/internal/config" ) //nolint:funlen @@ -16,15 +18,15 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // const Delete{StructName}Query = `DELETE FROM {table_name} WHERE {pk1} = ? [AND {pk2} = ?]` // - // func (q *query) Delete{StructName}(ctx context.Context, queryer sqlContext, pk1 pk1type [, pk2 pk2type]) error { - // if _, err := queryer.ExecContext(ctx, Delete{StructName}Query, pk1 [, pk2]); err != nil { - // return fmt.Errorf("q.queryer.ExecContext: %w", err) + // func (q *query) Delete{StructName}(ctx context.Context, queryer sqlQueryerContext, pk1 pk1type [, pk2 pk2type]) error { + // if _, err := sqlContext.ExecContext(ctx, Delete{StructName}Query, pk1 [, pk2]); err != nil { + // return fmt.Errorf("sqlContext.ExecContext: %w", err) // } // return nil // } funcName := "Delete" + structName + "ByPK" queryName := funcName + "Query" - pkColumns := tableInfo.PrimaryKeys() + pkColumns := tableInfo.Columns.PrimaryKeys() pkColumnNames := func() (pkColumnNames []string) { for _, c := range pkColumns { pkColumnNames = append(pkColumnNames, c.ColumnName) @@ -39,24 +41,24 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: queryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`DELETE FROM " + tableName + " WHERE " + strings.Join(pkColumnNames, " = ? AND ") + " = ?`", + Value: "`DELETE FROM " + tableName + " WHERE " + whereColumnsPlaceholder(pkColumnNames, "AND") + "`", }}, }, }, }, &ast.FuncDecl{ - Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}}, + Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}}}}, Name: &ast.Ident{Name: funcName}, Type: &ast.FuncType{ Params: &ast.FieldList{List: append( []*ast.Field{ {Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, - {Names: []*ast.Ident{{Name: "sqlCtx"}}, Type: &ast.Ident{Name: "sqlContext"}}, + {Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, Type: &ast.Ident{Name: sqlQueryerContextTypeName}}, }, func() []*ast.Field { var fields []*ast.Field for _, c := range pkColumns { - fields = append(fields, &ast.Field{Names: []*ast.Ident{{Name: c.ColumnName}}, Type: &ast.Ident{Name: c.FieldType}}) + fields = append(fields, &ast.Field{Names: []*ast.Ident{{Name: util.PascalCaseToCamelCase(c.FieldName)}}, Type: &ast.Ident{Name: c.FieldType}}) } return fields }()..., @@ -67,14 +69,24 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { }, Body: &ast.BlockStmt{ List: []ast.Stmt{ + &ast.ExprStmt{ + // LoggerFromContext(ctx).Debug(queryName) + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}}, + Sel: &ast.Ident{Name: "Debug"}, + }, + Args: []ast.Expr{&ast.Ident{Name: queryName}}, + }, + }, &ast.IfStmt{ - // if _, err := queryer.ExecContext(ctx, Delete{StructName}Query, pk1 [, pk2]); err != nil { + // if _, err := sqlContext.ExecContext(ctx, Delete{StructName}Query, pk1 [, pk2]); err != nil { Init: &ast.AssignStmt{ Lhs: []ast.Expr{&ast.Ident{Name: "_"}, &ast.Ident{Name: "err"}}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{ - X: &ast.Ident{Name: "sqlCtx"}, + X: &ast.Ident{Name: sqlQueryerContextVarName}, Sel: &ast.Ident{Name: "ExecContext"}, }, Args: append( @@ -85,7 +97,7 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { func() []ast.Expr { var args []ast.Expr for _, c := range pkColumns { - args = append(args, &ast.Ident{Name: c.ColumnName}) + args = append(args, &ast.Ident{Name: util.PascalCaseToCamelCase(c.FieldName)}) } return args }()..., @@ -95,10 +107,10 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // err != nil { Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "err"}, Op: token.NEQ, Y: &ast.Ident{Name: "nil"}}, Body: &ast.BlockStmt{List: []ast.Stmt{ - // return fmt.Errorf("queryer.ExecContext: %w", err) + // return fmt.Errorf("sqlContext.ExecContext: %w", err) &ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "fmt"}, Sel: &ast.Ident{Name: "Errorf"}}, - Args: []ast.Expr{&ast.Ident{Name: strconv.Quote("queryer.ExecContext: %w")}, &ast.Ident{Name: "err"}}, + Args: []ast.Expr{&ast.Ident{Name: strconv.Quote(sqlQueryerContextVarName + ".ExecContext: %w")}, &ast.Ident{Name: "err"}}, }}}, }}, }, diff --git a/internal/arcgen/lang/go/generate_crud_read.go b/internal/arcgen/lang/go/generate_crud_read.go index c753c27..56ba83f 100644 --- a/internal/arcgen/lang/go/generate_crud_read.go +++ b/internal/arcgen/lang/go/generate_crud_read.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + "github.com/kunitsucom/arcgen/internal/arcgen/lang/util" "github.com/kunitsucom/arcgen/internal/config" ) @@ -15,12 +16,12 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { structName := arcSrc.extractStructName() tableName := arcSrc.extractTableNameFromCommentGroup() tableInfo := arcSrc.extractFieldNamesAndColumnNames() - columnNames := tableInfo.ColumnNames() - pks := tableInfo.PrimaryKeys() + columnNames := tableInfo.Columns.ColumnNames() + pks := tableInfo.Columns.PrimaryKeys() // const Find{StructName}ByPKQuery = `SELECT {column_name1}, {column_name2} FROM {table_name} WHERE {pk1} = ? [AND ...]` // - // func (q *query) Find{StructName}ByPK(ctx context.Context, queryer sqlContext, pk1 pk1type, ...) ({Struct}, error) { + // func (q *query) Find{StructName}ByPK(ctx context.Context, queryer sqlQueryerContext, pk1 pk1type, ...) ({Struct}, error) { // row := queryer.QueryRowContext(ctx, Find{StructName}Query, pk1, ...) // var s {Struct} // if err := row.Scan( @@ -40,14 +41,8 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { &ast.ValueSpec{ Names: []*ast.Ident{{Name: byPKQueryName}}, Values: []ast.Expr{&ast.BasicLit{ - Kind: token.STRING, - Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + func() string { - var where []string - for _, pk := range pks { - where = append(where, pk.ColumnName+" = ?") - } - return strings.Join(where, " AND ") - }() + "`", + Kind: token.STRING, + Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(pks.ColumnNames(), "AND") + "`", }}, }, }, @@ -56,7 +51,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Name: &ast.Ident{Name: byPKFuncName}, Recv: &ast.FieldList{List: []*ast.Field{{ Names: []*ast.Ident{{Name: "q"}}, - Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}, + Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}, }}}, Type: &ast.FuncType{ Params: &ast.FieldList{ @@ -66,15 +61,15 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Type: &ast.Ident{Name: "context.Context"}, }, { - Names: []*ast.Ident{{Name: "sqlCtx"}}, - Type: &ast.Ident{Name: "sqlContext"}, + Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, + Type: &ast.Ident{Name: sqlQueryerContextTypeName}, }, }, func() []*ast.Field { fields := make([]*ast.Field, 0) for _, pk := range pks { fields = append(fields, &ast.Field{ - Names: []*ast.Ident{{Name: pk.ColumnName}}, + Names: []*ast.Ident{{Name: util.PascalCaseToCamelCase(pk.FieldName)}}, Type: &ast.Ident{Name: pk.FieldType}, }) } @@ -89,11 +84,21 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Body: &ast.BlockStmt{ // row, err := queryer.QueryRowContext(ctx, Find{StructName}Query, pk1, ...) List: []ast.Stmt{ + &ast.ExprStmt{ + // LoggerFromContext(ctx).Debug(queryName) + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}}, + Sel: &ast.Ident{Name: "Debug"}, + }, + Args: []ast.Expr{&ast.Ident{Name: byPKQueryName}}, + }, + }, &ast.AssignStmt{ Lhs: []ast.Expr{&ast.Ident{Name: "row"}}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "sqlCtx"}, Sel: &ast.Ident{Name: "QueryRowContext"}}, + Fun: &ast.SelectorExpr{X: &ast.Ident{Name: sqlQueryerContextVarName}, Sel: &ast.Ident{Name: "QueryRowContext"}}, Args: append( []ast.Expr{ &ast.Ident{Name: "ctx"}, @@ -102,7 +107,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { func() []ast.Expr { var args []ast.Expr for _, pk := range pks { - args = append(args, &ast.Ident{Name: pk.ColumnName}) + args = append(args, &ast.Ident{Name: util.PascalCaseToCamelCase(pk.FieldName)}) } return args }()...), @@ -162,7 +167,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { for _, hasOneTag := range tableInfo.HasOneTags { // const Find{StructName}By{FieldName}Query = `SELECT {column_name1}, {column_name2} FROM {table_name} WHERE {column} = ? [AND ...]` // - // func (q *Queryer) Find{StructName}ByColumn1[AndColumn2](ctx context.Context, queryer sqlContext, {ColumnName} {ColumnType} [, {Column2Name} {Column2Type}]) ({Struct}Slice, error) { + // func (q *queryer) Find{StructName}ByColumn1[AndColumn2](ctx context.Context, queryer sqlQueryerContext, {ColumnName} {ColumnType} [, {Column2Name} {Column2Type}]) ({Struct}Slice, error) { // row := queryer.QueryRowContext(ctx, Find{StructName}Query, {ColumnName}, {Column2Name}) // var s {Struct} // if err := row.Scan( @@ -183,14 +188,8 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { &ast.ValueSpec{ Names: []*ast.Ident{{Name: byHasOneTagQueryName}}, Values: []ast.Expr{&ast.BasicLit{ - Kind: token.STRING, - Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + func() string { - var where []string - for _, hasOneColumn := range hasOneColumns { - where = append(where, hasOneColumn.ColumnName+" = ?") - } - return strings.Join(where, " AND ") - }() + "`", + Kind: token.STRING, + Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasOneColumns.ColumnNames(), "AND") + "`", }}, }, }, @@ -199,7 +198,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Name: &ast.Ident{Name: byHasOneTagFuncName}, Recv: &ast.FieldList{List: []*ast.Field{{ Names: []*ast.Ident{{Name: "q"}}, - Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}, + Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}, }}}, Type: &ast.FuncType{ Params: &ast.FieldList{ @@ -209,16 +208,16 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Type: &ast.Ident{Name: "context.Context"}, }, { - Names: []*ast.Ident{{Name: "sqlCtx"}}, - Type: &ast.Ident{Name: "sqlContext"}, + Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, + Type: &ast.Ident{Name: sqlQueryerContextTypeName}, }, }, func() []*ast.Field { fields := make([]*ast.Field, 0) - for _, hasOneColumn := range hasOneColumns { + for _, c := range hasOneColumns { fields = append(fields, &ast.Field{ - Names: []*ast.Ident{{Name: hasOneColumn.ColumnName}}, - Type: &ast.Ident{Name: hasOneColumn.FieldType}, + Names: []*ast.Ident{{Name: util.PascalCaseToCamelCase(c.FieldName)}}, + Type: &ast.Ident{Name: c.FieldType}, }) } return fields @@ -232,11 +231,21 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Body: &ast.BlockStmt{ // row, err := queryer.QueryRowContext(ctx, Find{StructName}Query, column1, ...) List: []ast.Stmt{ + &ast.ExprStmt{ + // LoggerFromContext(ctx).Debug(queryName) + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}}, + Sel: &ast.Ident{Name: "Debug"}, + }, + Args: []ast.Expr{&ast.Ident{Name: byHasOneTagQueryName}}, + }, + }, &ast.AssignStmt{ Lhs: []ast.Expr{&ast.Ident{Name: "row"}}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "sqlCtx"}, Sel: &ast.Ident{Name: "QueryRowContext"}}, + Fun: &ast.SelectorExpr{X: &ast.Ident{Name: sqlQueryerContextVarName}, Sel: &ast.Ident{Name: "QueryRowContext"}}, Args: append( []ast.Expr{ &ast.Ident{Name: "ctx"}, @@ -245,7 +254,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { func() []ast.Expr { var args []ast.Expr for _, c := range hasOneColumns { - args = append(args, &ast.Ident{Name: c.ColumnName}) + args = append(args, &ast.Ident{Name: util.PascalCaseToCamelCase(c.FieldName)}) } return args }()...), @@ -306,7 +315,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { for _, hasManyTag := range tableInfo.HasManyTags { // const List{StructName}By{FieldName}Query = `SELECT {column_name1}, {column_name2} FROM {table_name} WHERE {pk1} = ? [AND ...]` // - // func (q *query) List{StructName}ByColumn1[AndColumn2](ctx context.Context, queryer sqlContext, {ColumnName} {ColumnType} [, {Column2Name} {Column2Type}]) ({Struct}Slice, error) { + // func (q *query) List{StructName}ByColumn1[AndColumn2](ctx context.Context, queryer sqlQueryerContext, {ColumnName} {ColumnType} [, {Column2Name} {Column2Type}]) ({Struct}Slice, error) { // rows, err := queryer.QueryContext(ctx, List{StructName}Query, {ColumnName}, {Column2Name}) // if err != nil { // return nil, fmt.Errorf("queryer.QueryContext: %w", err) @@ -344,14 +353,8 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { &ast.ValueSpec{ Names: []*ast.Ident{{Name: byHasOneTagQueryName}}, Values: []ast.Expr{&ast.BasicLit{ - Kind: token.STRING, - Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + func() string { - var where []string - for _, c := range hasManyColumns { - where = append(where, c.ColumnName+" = ?") - } - return strings.Join(where, " AND ") - }() + "`", + Kind: token.STRING, + Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasManyColumns.ColumnNames(), "AND") + "`", }}, }, }, @@ -360,22 +363,19 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Name: &ast.Ident{Name: byHasOneTagFuncName}, Recv: &ast.FieldList{List: []*ast.Field{{ Names: []*ast.Ident{{Name: "q"}}, - Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}, + Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}, }}}, Type: &ast.FuncType{ Params: &ast.FieldList{ List: append( []*ast.Field{ {Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, - {Names: []*ast.Ident{{Name: "sqlCtx"}}, Type: &ast.Ident{Name: "sqlContext"}}, + {Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, Type: &ast.Ident{Name: sqlQueryerContextTypeName}}, }, func() []*ast.Field { fields := make([]*ast.Field, 0) for _, c := range hasManyColumns { - fields = append(fields, &ast.Field{ - Names: []*ast.Ident{{Name: c.ColumnName}}, - Type: &ast.Ident{Name: c.FieldType}, - }) + fields = append(fields, &ast.Field{Names: []*ast.Ident{{Name: util.PascalCaseToCamelCase(c.FieldName)}}, Type: &ast.Ident{Name: c.FieldType}}) } return fields }()..., @@ -388,11 +388,21 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { }, Body: &ast.BlockStmt{ List: []ast.Stmt{ + &ast.ExprStmt{ + // LoggerFromContext(ctx).Debug(queryName) + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}}, + Sel: &ast.Ident{Name: "Debug"}, + }, + Args: []ast.Expr{&ast.Ident{Name: byHasOneTagQueryName}}, + }, + }, &ast.AssignStmt{ Lhs: []ast.Expr{&ast.Ident{Name: "rows"}, &ast.Ident{Name: "err"}}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "sqlCtx"}, Sel: &ast.Ident{Name: "QueryContext"}}, + Fun: &ast.SelectorExpr{X: &ast.Ident{Name: sqlQueryerContextVarName}, Sel: &ast.Ident{Name: "QueryContext"}}, Args: append( []ast.Expr{ &ast.Ident{Name: "ctx"}, @@ -401,7 +411,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { func() []ast.Expr { var args []ast.Expr for _, c := range hasManyColumns { - args = append(args, &ast.Ident{Name: c.ColumnName}) + args = append(args, &ast.Ident{Name: util.PascalCaseToCamelCase(c.FieldName)}) } return args }()..., diff --git a/internal/arcgen/lang/go/generate_crud_update.go b/internal/arcgen/lang/go/generate_crud_update.go index 26a9cf4..da41a73 100644 --- a/internal/arcgen/lang/go/generate_crud_update.go +++ b/internal/arcgen/lang/go/generate_crud_update.go @@ -5,6 +5,8 @@ import ( "go/token" "strconv" "strings" + + "github.com/kunitsucom/arcgen/internal/config" ) //nolint:funlen @@ -16,22 +18,16 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // const Update{StructName}Query = `UPDATE {table_name} SET ({column_name1}, {column_name2}) = (?, ?) WHERE {pk1} = ? [AND {pk2} = ?]` // - // func (q *query) Update{StructName}(ctx context.Context, queryer sqlContext, s *{Struct}) error { - // if _, err := queryer.ExecContext(ctx, Update{StructName}Query, s.{ColumnName1}, s.{ColumnName2}, s.{PK1} [, s.{PK2}]); err != nil { - // return fmt.Errorf("q.queryer.ExecContext: %w", err) + // func (q *query) Update{StructName}(ctx context.Context, queryer sqlQueryerContext, s *{Struct}) error { + // if _, err := sqlContext.ExecContext(ctx, Update{StructName}Query, s.{ColumnName1}, s.{ColumnName2}, s.{PK1} [, s.{PK2}]); err != nil { + // return fmt.Errorf("sqlContext.ExecContext: %w", err) // } // return nil // } funcName := "Update" + structName queryName := funcName + "Query" - pkColumns := tableInfo.PrimaryKeys() - pkColumnNames := func() (pkColumnNames []string) { - for _, c := range pkColumns { - pkColumnNames = append(pkColumnNames, c.ColumnName) - } - return pkColumnNames - }() - nonPKColumns := tableInfo.NonPrimaryKeys() + pkColumns := tableInfo.Columns.PrimaryKeys() + nonPKColumns := tableInfo.Columns.NonPrimaryKeys() nonPKColumnNames := func() (nonPKColumnNames []string) { for _, c := range nonPKColumns { nonPKColumnNames = append(nonPKColumnNames, c.ColumnName) @@ -46,18 +42,18 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: queryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`UPDATE " + tableName + " SET (" + strings.Join(nonPKColumnNames, ", ") + ") = (?" + strings.Repeat(", ?", len(nonPKColumns)-1) + ") WHERE " + strings.Join(pkColumnNames, " = ? AND ") + " = ?`", + Value: "`UPDATE " + tableName + " SET (" + strings.Join(nonPKColumnNames, ", ") + ") = (?" + strings.Repeat(", ?", len(nonPKColumns)-1) + ") WHERE " + whereColumnsPlaceholder(pkColumns.ColumnNames(), "AND") + "`", }}, }, }, }, &ast.FuncDecl{ - Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}}, + Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}}}}, Name: &ast.Ident{Name: funcName}, Type: &ast.FuncType{ Params: &ast.FieldList{List: []*ast.Field{ {Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, - {Names: []*ast.Ident{{Name: "sqlCtx"}}, Type: &ast.Ident{Name: "sqlContext"}}, + {Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, Type: &ast.Ident{Name: sqlQueryerContextTypeName}}, {Names: []*ast.Ident{{Name: "s"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "dao." + structName}}}, }}, Results: &ast.FieldList{List: []*ast.Field{ @@ -66,14 +62,24 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { }, Body: &ast.BlockStmt{ List: []ast.Stmt{ + &ast.ExprStmt{ + // LoggerFromContext(ctx).Debug(queryName) + X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}}, + Sel: &ast.Ident{Name: "Debug"}, + }, + Args: []ast.Expr{&ast.Ident{Name: queryName}}, + }, + }, &ast.IfStmt{ - // if _, err := queryer.ExecContext(ctx, Update{StructName}Query, s.{ColumnName1}, s.{ColumnName2}, s.{PK1} [, s.{PK2}]); err != nil { + // if _, err := sqlContext.ExecContext(ctx, Update{StructName}Query, s.{ColumnName1}, s.{ColumnName2}, s.{PK1} [, s.{PK2}]); err != nil { Init: &ast.AssignStmt{ Lhs: []ast.Expr{&ast.Ident{Name: "_"}, &ast.Ident{Name: "err"}}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{ - X: &ast.Ident{Name: "sqlCtx"}, + X: &ast.Ident{Name: sqlQueryerContextVarName}, Sel: &ast.Ident{Name: "ExecContext"}, }, Args: append( @@ -102,10 +108,10 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // err != nil { Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "err"}, Op: token.NEQ, Y: &ast.Ident{Name: "nil"}}, Body: &ast.BlockStmt{List: []ast.Stmt{ - // return fmt.Errorf("queryer.ExecContext: %w", err) + // return fmt.Errorf("sqlContext.ExecContext: %w", err) &ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "fmt"}, Sel: &ast.Ident{Name: "Errorf"}}, - Args: []ast.Expr{&ast.Ident{Name: strconv.Quote("queryer.ExecContext: %w")}, &ast.Ident{Name: "err"}}, + Args: []ast.Expr{&ast.Ident{Name: strconv.Quote(sqlQueryerContextVarName + ".ExecContext: %w")}, &ast.Ident{Name: "err"}}, }}}, }}, }, diff --git a/internal/arcgen/lang/go/source.go b/internal/arcgen/lang/go/source.go index 332f6a2..cc0b7ce 100644 --- a/internal/arcgen/lang/go/source.go +++ b/internal/arcgen/lang/go/source.go @@ -7,6 +7,7 @@ import ( "reflect" "regexp" "slices" + "strconv" "strings" "sync" @@ -77,20 +78,22 @@ func (a *ARCSource) extractTableNameFromCommentGroup() string { type TableInfo struct { HasOneTags []string HasManyTags []string - Columns []*ColumnInfo + Columns ColumnInfos } -func (t *TableInfo) ColumnNames() []string { - columnNames := make([]string, len(t.Columns)) - for i := range t.Columns { - columnNames[i] = t.Columns[i].ColumnName +type ColumnInfos []*ColumnInfo + +func (ss ColumnInfos) ColumnNames() []string { + columnNames := make([]string, len(ss)) + for i := range ss { + columnNames[i] = ss[i].ColumnName } return columnNames } -func (t *TableInfo) PrimaryKeys() []*ColumnInfo { - pks := make([]*ColumnInfo, 0, len(t.Columns)) - for _, column := range t.Columns { +func (ss ColumnInfos) PrimaryKeys() ColumnInfos { + pks := make(ColumnInfos, 0, len(ss)) + for _, column := range ss { if column.PK { pks = append(pks, column) } @@ -98,9 +101,9 @@ func (t *TableInfo) PrimaryKeys() []*ColumnInfo { return pks } -func (t *TableInfo) NonPrimaryKeys() []*ColumnInfo { - nonPks := make([]*ColumnInfo, 0, len(t.Columns)) - for _, column := range t.Columns { +func (ss ColumnInfos) NonPrimaryKeys() ColumnInfos { + nonPks := make(ColumnInfos, 0, len(ss)) + for _, column := range ss { if !column.PK { nonPks = append(nonPks, column) } @@ -108,10 +111,10 @@ func (t *TableInfo) NonPrimaryKeys() []*ColumnInfo { return nonPks } -func (t *TableInfo) HasOneTagColumnsByTag() map[string][]*ColumnInfo { - columns := make(map[string][]*ColumnInfo) +func (t *TableInfo) HasOneTagColumnsByTag() map[string]ColumnInfos { + columns := make(map[string]ColumnInfos) for _, hasOneTagInTable := range t.HasOneTags { - columns[hasOneTagInTable] = make([]*ColumnInfo, 0, len(t.Columns)) + columns[hasOneTagInTable] = make(ColumnInfos, 0, len(t.Columns)) for _, column := range t.Columns { for _, hasOneTag := range column.HasOneTags { if hasOneTagInTable == hasOneTag { @@ -124,10 +127,10 @@ func (t *TableInfo) HasOneTagColumnsByTag() map[string][]*ColumnInfo { return columns } -func (t *TableInfo) HasManyTagColumnsByTag() map[string][]*ColumnInfo { - columns := make(map[string][]*ColumnInfo) +func (t *TableInfo) HasManyTagColumnsByTag() map[string]ColumnInfos { + columns := make(map[string]ColumnInfos) for _, hasManyTagInTable := range t.HasManyTags { - columns[hasManyTagInTable] = make([]*ColumnInfo, 0, len(t.Columns)) + columns[hasManyTagInTable] = make(ColumnInfos, 0, len(t.Columns)) for _, column := range t.Columns { for _, hasManyTag := range column.HasManyTags { if hasManyTagInTable == hasManyTag { @@ -149,6 +152,48 @@ type ColumnInfo struct { HasManyTags []string } +func columnValuesPlaceholder(columns []string) string { + switch config.Dialect() { + case "mysql": + // ?, ?, ?, ... + return "?" + strings.Repeat(", ?", len(columns)-1) + default: + return func() string { + // $1, $2, $3, ... + var s strings.Builder + s.WriteString("$1") + for i := 2; i <= len(columns); i++ { + s.WriteString(", $") + s.WriteString(strconv.Itoa(i)) + } + return s.String() + }() + } +} + +//nolint:unparam +func whereColumnsPlaceholder(columns []string, op string) string { + switch config.Dialect() { + case "mysql": + // column1 = ? AND column2 = ? AND column3 = ... + return strings.Join(columns, " = ? "+op+" ") + " = ?" + default: + return func() string { + // column1 = $1 AND column2 = $2 AND column3 = ... + var s strings.Builder + for i, column := range columns { + if i > 0 { + s.WriteString(" " + op + " ") + } + s.WriteString(column) + s.WriteString(" = $") + s.WriteString(strconv.Itoa(i + 1)) + } + return s.String() + }() + } +} + func fieldName(x ast.Expr) *ast.Ident { switch t := x.(type) { case *ast.Ident: diff --git a/internal/arcgen/lang/util/camel_case.go b/internal/arcgen/lang/util/camel_case.go new file mode 100644 index 0000000..6edd8b7 --- /dev/null +++ b/internal/arcgen/lang/util/camel_case.go @@ -0,0 +1,12 @@ +package util + +func PascalCaseToCamelCase(s string) string { + if len(s) == 0 { + return s + } + if s[0] >= 'A' && s[0] <= 'Z' { + return string(s[0]+'a'-'A') + s[1:] + } + + return s +} diff --git a/internal/config/config.go b/internal/config/config.go index b88380a..e08661c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,11 +19,13 @@ type config struct { Version bool `json:"version"` Trace bool `json:"trace"` Debug bool `json:"debug"` + Dialect string `json:"dialect"` Language string `json:"language"` // Golang GoColumnTag string `json:"go_column_tag"` GoCRUDPackagePath string `json:"go_crud_package_path"` GoCRUDPackageName string `json:"go_crud_package_name"` + GoCRUDTypeName string `json:"go_crud_type_name"` GoHasManyTag string `json:"go_has_many_tag"` GoHasOneTag string `json:"go_has_one_tag"` GoMethodNameTable string `json:"go_method_name_table"` @@ -78,6 +80,9 @@ const ( _OptionDebug = "debug" _EnvKeyDebug = "ARCGEN_DEBUG" + _OptionDialect = "dialect" + _EnvKeyDialect = "ARCGEN_DIALECT" + _OptionLanguage = "lang" _EnvKeyLanguage = "ARCGEN_LANGUAGE" @@ -92,6 +97,9 @@ const ( _OptionGoCRUDPackageName = "go-crud-package-name" _EnvKeyGoCRUDPackageName = "ARCGEN_GO_CRUD_PACKAGE_NAME" + _OptionGoCRUDTypeName = "go-crud-type-name" + _EnvKeyGoCRUDTypeName = "ARCGEN_GO_CRUD_TYPE_NAME" + _OptionGoHasManyTag = "go-has-many-tag" _EnvKeyGoHasManyTag = "ARCGEN_GO_HAS_MANY_TAG" @@ -138,6 +146,11 @@ func load(ctx context.Context) (cfg *config, remainingArgs []string, err error) Description: "debug mode", Default: cliz.Default(false), }, + &cliz.StringOption{ + Name: _OptionDialect, Environment: _EnvKeyDialect, + Description: "dialect for DML", + Default: cliz.Default("postgres"), + }, &cliz.StringOption{ Name: _OptionLanguage, Environment: _EnvKeyLanguage, Description: "programming language to generate DDL", @@ -159,6 +172,11 @@ func load(ctx context.Context) (cfg *config, remainingArgs []string, err error) Description: "package name for CRUD", Default: cliz.Default(""), }, + &cliz.StringOption{ + Name: _OptionGoCRUDTypeName, Environment: _EnvKeyGoCRUDTypeName, + Description: "type name for CRUD", + Default: cliz.Default("CRUD"), + }, &cliz.StringOption{ Name: _OptionGoHasManyTag, Environment: _EnvKeyGoHasManyTag, Description: "\"hasMany\" annotation key for Go struct tag", @@ -206,11 +224,13 @@ func load(ctx context.Context) (cfg *config, remainingArgs []string, err error) Version: loadVersion(ctx, cmd), Trace: loadTrace(ctx, cmd), Debug: loadDebug(ctx, cmd), + Dialect: loadDialect(ctx, cmd), Language: loadLanguage(ctx, cmd), // Golang GoColumnTag: loadGoColumnTag(ctx, cmd), GoCRUDPackagePath: loadGoCRUDPackagePath(ctx, cmd), GoCRUDPackageName: loadGoCRUDPackageName(ctx, cmd), + GoCRUDTypeName: loadGoCRUDTypeName(ctx, cmd), GoHasManyTag: loadGoHasManyTag(ctx, cmd), GoHasOneTag: loadGoHasOneTag(ctx, cmd), GoMethodNameTable: loadGoMethodNameTable(ctx, cmd), diff --git a/internal/config/dialect.go b/internal/config/dialect.go new file mode 100644 index 0000000..4b49fa1 --- /dev/null +++ b/internal/config/dialect.go @@ -0,0 +1,18 @@ +package config + +import ( + "context" + + cliz "github.com/kunitsucom/util.go/exp/cli" +) + +func loadDialect(_ context.Context, cmd *cliz.Command) string { + v, _ := cmd.GetOptionString(_OptionDialect) + return v +} + +func Dialect() string { + globalConfigMu.RLock() + defer globalConfigMu.RUnlock() + return globalConfig.Dialect +} diff --git a/internal/config/go_crud_type_name.go b/internal/config/go_crud_type_name.go new file mode 100644 index 0000000..424c8a5 --- /dev/null +++ b/internal/config/go_crud_type_name.go @@ -0,0 +1,23 @@ +package config + +import ( + "context" + + cliz "github.com/kunitsucom/util.go/exp/cli" +) + +func loadGoCRUDTypeName(_ context.Context, cmd *cliz.Command) string { + v, _ := cmd.GetOptionString(_OptionGoCRUDTypeName) + return v +} + +func GoCRUDTypeName() string { + globalConfigMu.Lock() + defer globalConfigMu.Unlock() + + return globalConfig.GoCRUDTypeName +} + +func GoCRUDTypeNameUnexported() string { + return "_" + GoCRUDTypeName() +}