From a0668a3749689bb0c3d6f66a891393843d316376 Mon Sep 17 00:00:00 2001 From: ginokent <29125616+ginokent@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:13:26 +0900 Subject: [PATCH] fix: Add quote to generated query --- internal/arcgen/lang/go/consts.go | 14 +++++++++++++ internal/arcgen/lang/go/dialect.go | 11 +++++----- internal/arcgen/lang/go/generate_orm.go | 2 +- .../arcgen/lang/go/generate_orm_common.go | 11 +--------- .../arcgen/lang/go/generate_orm_create.go | 6 +++--- .../arcgen/lang/go/generate_orm_delete.go | 4 ++-- internal/arcgen/lang/go/generate_orm_read.go | 7 +++---- .../arcgen/lang/go/generate_orm_update.go | 6 +++--- ...package_name.go => package_import_name.go} | 2 +- internal/arcgen/lang/util/quote.go | 21 +++++++++++++++++++ 10 files changed, 55 insertions(+), 29 deletions(-) create mode 100644 internal/arcgen/lang/go/consts.go rename internal/arcgen/lang/util/{package_name.go => package_import_name.go} (91%) create mode 100644 internal/arcgen/lang/util/quote.go diff --git a/internal/arcgen/lang/go/consts.go b/internal/arcgen/lang/go/consts.go new file mode 100644 index 0000000..dbde56a --- /dev/null +++ b/internal/arcgen/lang/go/consts.go @@ -0,0 +1,14 @@ +package arcgengo + +const ( + importName = "orm" + receiverName = "_orm" + queryerContextVarName = "dbtx" + queryerContextTypeName = "DBTX" + createFuncPrefix = "Create" + readOneFuncPrefix = "Get" + readManyFuncPrefix = "List" + updateFuncPrefix = "Update" + deleteFuncPrefix = "Delete" + quote = `"` +) diff --git a/internal/arcgen/lang/go/dialect.go b/internal/arcgen/lang/go/dialect.go index cd19c63..4a3ba94 100644 --- a/internal/arcgen/lang/go/dialect.go +++ b/internal/arcgen/lang/go/dialect.go @@ -4,6 +4,7 @@ import ( "strconv" "strings" + "github.com/kunitsucom/arcgen/internal/arcgen/lang/util" "github.com/kunitsucom/arcgen/internal/config" ) @@ -54,7 +55,7 @@ func whereColumnsPlaceholder(columns []string, op string, initialNumber int) str switch config.Dialect() { case "mysql", "sqlite3": // column1 = ? AND column2 = ? AND column3 = ... - return strings.Join(columns, " = ? "+op+" ") + " = ?" + return util.JoinStringsWithQuote(columns, " = ? "+op+" ", quote) + " = ?" case "postgres", "cockroach": // column1 = $1 AND column2 = $2 AND column3 = ... var s strings.Builder @@ -62,7 +63,7 @@ func whereColumnsPlaceholder(columns []string, op string, initialNumber int) str if i > 0 { s.WriteString(" " + op + " ") } - s.WriteString(column) + s.WriteString(util.QuoteString(column, quote)) s.WriteString(" = $") s.WriteString(strconv.Itoa(i + initialNumber)) } @@ -74,7 +75,7 @@ func whereColumnsPlaceholder(columns []string, op string, initialNumber int) str if i > 0 { s.WriteString(" " + op + " ") } - s.WriteString(column) + s.WriteString(util.QuoteString(column, quote)) s.WriteString(" = @") s.WriteString(column) } @@ -86,13 +87,13 @@ func whereColumnsPlaceholder(columns []string, op string, initialNumber int) str if i > 0 { s.WriteString(" " + op + " ") } - s.WriteString(column) + s.WriteString(util.QuoteString(column, quote)) s.WriteString(" = :") s.WriteString(column) } return s.String() default: // column1 = ? AND column2 = ? AND column3 = ... - return strings.Join(columns, " = ? "+op+" ") + " = ?" + return util.JoinStringsWithQuote(columns, " = ? "+op+" ", quote) + " = ?" } } diff --git a/internal/arcgen/lang/go/generate_orm.go b/internal/arcgen/lang/go/generate_orm.go index 65f0982..7ce94ff 100644 --- a/internal/arcgen/lang/go/generate_orm.go +++ b/internal/arcgen/lang/go/generate_orm.go @@ -47,7 +47,7 @@ func generateORMFileContent(buf buffer, arcSrcSet *ARCSourceSet) (string, error) structPackageImportPath := config.GoORMStructPackageImportPath() if structPackageImportPath == "" { var err error - structPackageImportPath, err = util.GetPackageImportPath(filepath.Dir(arcSrcSet.Filename)) + structPackageImportPath, err = util.DetectPackageImportPath(filepath.Dir(arcSrcSet.Filename)) if err != nil { return "", errorz.Errorf("GetPackagePath: %w", err) } diff --git a/internal/arcgen/lang/go/generate_orm_common.go b/internal/arcgen/lang/go/generate_orm_common.go index 39f242b..b3553a5 100644 --- a/internal/arcgen/lang/go/generate_orm_common.go +++ b/internal/arcgen/lang/go/generate_orm_common.go @@ -16,15 +16,6 @@ import ( "github.com/kunitsucom/arcgen/internal/config" ) -const ( - importName = "orm" - receiverName = "_orm" - queryerContextVarName = "dbtx" - queryerContextTypeName = "DBTX" - readOneFuncPrefix = "Get" - readManyFuncPrefix = "List" -) - func fprintORMCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice, ormFiles []string) error { content, err := generateORMCommonFileContent(buf, arcSrcSetSlice, ormFiles) if err != nil { @@ -55,7 +46,7 @@ func generateORMCommonFileContent(buf buffer, arcSrcSetSlice ARCSourceSetSlice, structPackageImportPath := config.GoORMStructPackageImportPath() if structPackageImportPath == "" { var err error - structPackageImportPath, err = util.GetPackageImportPath(filepath.Dir(arcSrcSetSlice[0].Filename)) + structPackageImportPath, err = util.DetectPackageImportPath(filepath.Dir(arcSrcSetSlice[0].Filename)) if err != nil { return "", errorz.Errorf("GetPackagePath: %w", err) } diff --git a/internal/arcgen/lang/go/generate_orm_create.go b/internal/arcgen/lang/go/generate_orm_create.go index 2771fd0..3958363 100644 --- a/internal/arcgen/lang/go/generate_orm_create.go +++ b/internal/arcgen/lang/go/generate_orm_create.go @@ -4,8 +4,8 @@ import ( "go/ast" "go/token" "strconv" - "strings" + "github.com/kunitsucom/arcgen/internal/arcgen/lang/util" "github.com/kunitsucom/arcgen/internal/config" ) @@ -27,7 +27,7 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // } // return nil // } - funcName := "Create" + structName + funcName := createFuncPrefix + structName queryName := funcName + "Query" astFile.Decls = append(astFile.Decls, &ast.GenDecl{ @@ -37,7 +37,7 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: queryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (" + columnValuesPlaceholder(columnNames, 1) + ")`", + Value: "`INSERT INTO " + tableName + " (" + util.JoinStringsWithQuote(columnNames, ", ", `"`) + ") VALUES (" + columnValuesPlaceholder(columnNames, 1) + ")`", }}, }, }, diff --git a/internal/arcgen/lang/go/generate_orm_delete.go b/internal/arcgen/lang/go/generate_orm_delete.go index 0ff19bc..f6d0763 100644 --- a/internal/arcgen/lang/go/generate_orm_delete.go +++ b/internal/arcgen/lang/go/generate_orm_delete.go @@ -25,7 +25,7 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // } // return nil // } - funcName := "Delete" + structName + "ByPK" + funcName := deleteFuncPrefix + structName + "ByPK" queryName := funcName + "Query" pkColumns := tableInfo.Columns.PrimaryKeys() pkColumnNames := func() (pkColumnNames []string) { @@ -140,7 +140,7 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // } // return nil // } - byHasManyTagFuncName := "Delete" + structName + "By" + hasManyTag + byHasManyTagFuncName := deleteFuncPrefix + structName + "By" + hasManyTag byHasManyTagQueryName := byHasManyTagFuncName + "Query" hasManyColumns := hasManyColumnsByTag[hasManyTag] astFile.Decls = append(astFile.Decls, diff --git a/internal/arcgen/lang/go/generate_orm_read.go b/internal/arcgen/lang/go/generate_orm_read.go index 7f894b7..2c24852 100644 --- a/internal/arcgen/lang/go/generate_orm_read.go +++ b/internal/arcgen/lang/go/generate_orm_read.go @@ -4,7 +4,6 @@ import ( "go/ast" "go/token" "strconv" - "strings" "github.com/kunitsucom/arcgen/internal/arcgen/lang/util" "github.com/kunitsucom/arcgen/internal/config" @@ -43,7 +42,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: byPKQueryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(pks.ColumnNames(), "AND", 1) + "`", + Value: "`SELECT " + util.JoinStringsWithQuote(columnNames, ", ", `"`) + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(pks.ColumnNames(), "AND", 1) + "`", }}, }, }, @@ -192,7 +191,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: byHasOneTagQueryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasOneColumns.ColumnNames(), "AND", 1) + "`", + Value: "`SELECT " + util.JoinStringsWithQuote(columnNames, ", ", `"`) + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasOneColumns.ColumnNames(), "AND", 1) + "`", }}, }, }, @@ -359,7 +358,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: byHasOneTagQueryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasManyColumns.ColumnNames(), "AND", 1) + "`", + Value: "`SELECT " + util.JoinStringsWithQuote(columnNames, ", ", `"`) + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasManyColumns.ColumnNames(), "AND", 1) + "`", }}, }, }, diff --git a/internal/arcgen/lang/go/generate_orm_update.go b/internal/arcgen/lang/go/generate_orm_update.go index e8393b2..60e2eab 100644 --- a/internal/arcgen/lang/go/generate_orm_update.go +++ b/internal/arcgen/lang/go/generate_orm_update.go @@ -4,8 +4,8 @@ import ( "go/ast" "go/token" "strconv" - "strings" + "github.com/kunitsucom/arcgen/internal/arcgen/lang/util" "github.com/kunitsucom/arcgen/internal/config" ) @@ -25,7 +25,7 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { // } // return nil // } - funcName := "Update" + structName + funcName := updateFuncPrefix + structName queryName := funcName + "Query" pkColumns := tableInfo.Columns.PrimaryKeys() nonPKColumns := tableInfo.Columns.NonPrimaryKeys() @@ -43,7 +43,7 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) { Names: []*ast.Ident{{Name: queryName}}, Values: []ast.Expr{&ast.BasicLit{ Kind: token.STRING, - Value: "`UPDATE " + tableName + " SET (" + strings.Join(nonPKColumnNames, ", ") + ") = (" + columnValuesPlaceholder(nonPKColumnNames, 1) + ") WHERE " + whereColumnsPlaceholder(pkColumns.ColumnNames(), "AND", len(nonPKColumnNames)+1) + "`", + Value: "`UPDATE " + tableName + " SET (" + util.JoinStringsWithQuote(nonPKColumnNames, ", ", `"`) + ") = (" + columnValuesPlaceholder(nonPKColumnNames, 1) + ") WHERE " + whereColumnsPlaceholder(pkColumns.ColumnNames(), "AND", len(nonPKColumnNames)+1) + "`", }}, }, }, diff --git a/internal/arcgen/lang/util/package_name.go b/internal/arcgen/lang/util/package_import_name.go similarity index 91% rename from internal/arcgen/lang/util/package_name.go rename to internal/arcgen/lang/util/package_import_name.go index 118ee1e..abc11d9 100644 --- a/internal/arcgen/lang/util/package_name.go +++ b/internal/arcgen/lang/util/package_import_name.go @@ -8,7 +8,7 @@ import ( apperr "github.com/kunitsucom/arcgen/pkg/errors" ) -func GetPackageImportPath(path string) (string, error) { +func DetectPackageImportPath(path string) (string, error) { absDir, err := filepath.Abs(path) if err != nil { return "", fmt.Errorf("filepath.Abs: path=%s %w", path, err) diff --git a/internal/arcgen/lang/util/quote.go b/internal/arcgen/lang/util/quote.go new file mode 100644 index 0000000..e1fc903 --- /dev/null +++ b/internal/arcgen/lang/util/quote.go @@ -0,0 +1,21 @@ +package util + +import ( + "strings" +) + +func QuoteString(s string, quote string) string { + return quote + s + quote +} + +func JoinStringsWithQuote(ss []string, sep string, quote string) string { + if len(ss) == 0 { + return "" + } + + if len(ss) == 1 { + return QuoteString(ss[0], quote) + } + + return quote + strings.Join(ss, quote+sep+quote) + quote +}