Skip to content

Commit

Permalink
fix: Add quote to generated query (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
ginokent authored Aug 13, 2024
2 parents b2e881e + a0668a3 commit 182a62a
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 29 deletions.
14 changes: 14 additions & 0 deletions internal/arcgen/lang/go/consts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package arcgengo

const (
importName = "orm"
receiverName = "_orm"
queryerContextVarName = "dbtx"
queryerContextTypeName = "DBTX"
createFuncPrefix = "Create"
readOneFuncPrefix = "Get"
readManyFuncPrefix = "List"
updateFuncPrefix = "Update"
deleteFuncPrefix = "Delete"
quote = `"`
)
11 changes: 6 additions & 5 deletions internal/arcgen/lang/go/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"strconv"
"strings"

"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
"github.com/kunitsucom/arcgen/internal/config"
)

Expand Down Expand Up @@ -54,15 +55,15 @@ func whereColumnsPlaceholder(columns []string, op string, initialNumber int) str
switch config.Dialect() {
case "mysql", "sqlite3":
// column1 = ? AND column2 = ? AND column3 = ...
return strings.Join(columns, " = ? "+op+" ") + " = ?"
return util.JoinStringsWithQuote(columns, " = ? "+op+" ", quote) + " = ?"
case "postgres", "cockroach":
// column1 = $1 AND column2 = $2 AND column3 = ...
var s strings.Builder
for i, column := range columns {
if i > 0 {
s.WriteString(" " + op + " ")
}
s.WriteString(column)
s.WriteString(util.QuoteString(column, quote))
s.WriteString(" = $")
s.WriteString(strconv.Itoa(i + initialNumber))
}
Expand All @@ -74,7 +75,7 @@ func whereColumnsPlaceholder(columns []string, op string, initialNumber int) str
if i > 0 {
s.WriteString(" " + op + " ")
}
s.WriteString(column)
s.WriteString(util.QuoteString(column, quote))
s.WriteString(" = @")
s.WriteString(column)
}
Expand All @@ -86,13 +87,13 @@ func whereColumnsPlaceholder(columns []string, op string, initialNumber int) str
if i > 0 {
s.WriteString(" " + op + " ")
}
s.WriteString(column)
s.WriteString(util.QuoteString(column, quote))
s.WriteString(" = :")
s.WriteString(column)
}
return s.String()
default:
// column1 = ? AND column2 = ? AND column3 = ...
return strings.Join(columns, " = ? "+op+" ") + " = ?"
return util.JoinStringsWithQuote(columns, " = ? "+op+" ", quote) + " = ?"
}
}
2 changes: 1 addition & 1 deletion internal/arcgen/lang/go/generate_orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func generateORMFileContent(buf buffer, arcSrcSet *ARCSourceSet) (string, error)
structPackageImportPath := config.GoORMStructPackageImportPath()
if structPackageImportPath == "" {
var err error
structPackageImportPath, err = util.GetPackageImportPath(filepath.Dir(arcSrcSet.Filename))
structPackageImportPath, err = util.DetectPackageImportPath(filepath.Dir(arcSrcSet.Filename))
if err != nil {
return "", errorz.Errorf("GetPackagePath: %w", err)
}
Expand Down
11 changes: 1 addition & 10 deletions internal/arcgen/lang/go/generate_orm_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@ import (
"github.com/kunitsucom/arcgen/internal/config"
)

const (
importName = "orm"
receiverName = "_orm"
queryerContextVarName = "dbtx"
queryerContextTypeName = "DBTX"
readOneFuncPrefix = "Get"
readManyFuncPrefix = "List"
)

func fprintORMCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice, ormFiles []string) error {
content, err := generateORMCommonFileContent(buf, arcSrcSetSlice, ormFiles)
if err != nil {
Expand Down Expand Up @@ -55,7 +46,7 @@ func generateORMCommonFileContent(buf buffer, arcSrcSetSlice ARCSourceSetSlice,
structPackageImportPath := config.GoORMStructPackageImportPath()
if structPackageImportPath == "" {
var err error
structPackageImportPath, err = util.GetPackageImportPath(filepath.Dir(arcSrcSetSlice[0].Filename))
structPackageImportPath, err = util.DetectPackageImportPath(filepath.Dir(arcSrcSetSlice[0].Filename))
if err != nil {
return "", errorz.Errorf("GetPackagePath: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/arcgen/lang/go/generate_orm_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"go/ast"
"go/token"
"strconv"
"strings"

"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
"github.com/kunitsucom/arcgen/internal/config"
)

Expand All @@ -27,7 +27,7 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
// }
// return nil
// }
funcName := "Create" + structName
funcName := createFuncPrefix + structName
queryName := funcName + "Query"
astFile.Decls = append(astFile.Decls,
&ast.GenDecl{
Expand All @@ -37,7 +37,7 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
Names: []*ast.Ident{{Name: queryName}},
Values: []ast.Expr{&ast.BasicLit{
Kind: token.STRING,
Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (" + columnValuesPlaceholder(columnNames, 1) + ")`",
Value: "`INSERT INTO " + tableName + " (" + util.JoinStringsWithQuote(columnNames, ", ", `"`) + ") VALUES (" + columnValuesPlaceholder(columnNames, 1) + ")`",
}},
},
},
Expand Down
4 changes: 2 additions & 2 deletions internal/arcgen/lang/go/generate_orm_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
// }
// return nil
// }
funcName := "Delete" + structName + "ByPK"
funcName := deleteFuncPrefix + structName + "ByPK"
queryName := funcName + "Query"
pkColumns := tableInfo.Columns.PrimaryKeys()
pkColumnNames := func() (pkColumnNames []string) {
Expand Down Expand Up @@ -140,7 +140,7 @@ func generateDELETEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
// }
// return nil
// }
byHasManyTagFuncName := "Delete" + structName + "By" + hasManyTag
byHasManyTagFuncName := deleteFuncPrefix + structName + "By" + hasManyTag
byHasManyTagQueryName := byHasManyTagFuncName + "Query"
hasManyColumns := hasManyColumnsByTag[hasManyTag]
astFile.Decls = append(astFile.Decls,
Expand Down
7 changes: 3 additions & 4 deletions internal/arcgen/lang/go/generate_orm_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"go/ast"
"go/token"
"strconv"
"strings"

"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
"github.com/kunitsucom/arcgen/internal/config"
Expand Down Expand Up @@ -43,7 +42,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
Names: []*ast.Ident{{Name: byPKQueryName}},
Values: []ast.Expr{&ast.BasicLit{
Kind: token.STRING,
Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(pks.ColumnNames(), "AND", 1) + "`",
Value: "`SELECT " + util.JoinStringsWithQuote(columnNames, ", ", `"`) + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(pks.ColumnNames(), "AND", 1) + "`",
}},
},
},
Expand Down Expand Up @@ -192,7 +191,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
Names: []*ast.Ident{{Name: byHasOneTagQueryName}},
Values: []ast.Expr{&ast.BasicLit{
Kind: token.STRING,
Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasOneColumns.ColumnNames(), "AND", 1) + "`",
Value: "`SELECT " + util.JoinStringsWithQuote(columnNames, ", ", `"`) + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasOneColumns.ColumnNames(), "AND", 1) + "`",
}},
},
},
Expand Down Expand Up @@ -359,7 +358,7 @@ func generateREADContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
Names: []*ast.Ident{{Name: byHasOneTagQueryName}},
Values: []ast.Expr{&ast.BasicLit{
Kind: token.STRING,
Value: "`SELECT " + strings.Join(columnNames, ", ") + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasManyColumns.ColumnNames(), "AND", 1) + "`",
Value: "`SELECT " + util.JoinStringsWithQuote(columnNames, ", ", `"`) + " FROM " + tableName + " WHERE " + whereColumnsPlaceholder(hasManyColumns.ColumnNames(), "AND", 1) + "`",
}},
},
},
Expand Down
6 changes: 3 additions & 3 deletions internal/arcgen/lang/go/generate_orm_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"go/ast"
"go/token"
"strconv"
"strings"

"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
"github.com/kunitsucom/arcgen/internal/config"
)

Expand All @@ -25,7 +25,7 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
// }
// return nil
// }
funcName := "Update" + structName
funcName := updateFuncPrefix + structName
queryName := funcName + "Query"
pkColumns := tableInfo.Columns.PrimaryKeys()
nonPKColumns := tableInfo.Columns.NonPrimaryKeys()
Expand All @@ -43,7 +43,7 @@ func generateUPDATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
Names: []*ast.Ident{{Name: queryName}},
Values: []ast.Expr{&ast.BasicLit{
Kind: token.STRING,
Value: "`UPDATE " + tableName + " SET (" + strings.Join(nonPKColumnNames, ", ") + ") = (" + columnValuesPlaceholder(nonPKColumnNames, 1) + ") WHERE " + whereColumnsPlaceholder(pkColumns.ColumnNames(), "AND", len(nonPKColumnNames)+1) + "`",
Value: "`UPDATE " + tableName + " SET (" + util.JoinStringsWithQuote(nonPKColumnNames, ", ", `"`) + ") = (" + columnValuesPlaceholder(nonPKColumnNames, 1) + ") WHERE " + whereColumnsPlaceholder(pkColumns.ColumnNames(), "AND", len(nonPKColumnNames)+1) + "`",
}},
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
apperr "github.com/kunitsucom/arcgen/pkg/errors"
)

func GetPackageImportPath(path string) (string, error) {
func DetectPackageImportPath(path string) (string, error) {
absDir, err := filepath.Abs(path)
if err != nil {
return "", fmt.Errorf("filepath.Abs: path=%s %w", path, err)
Expand Down
21 changes: 21 additions & 0 deletions internal/arcgen/lang/util/quote.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package util

import (
"strings"
)

func QuoteString(s string, quote string) string {
return quote + s + quote
}

func JoinStringsWithQuote(ss []string, sep string, quote string) string {
if len(ss) == 0 {
return ""
}

if len(ss) == 1 {
return QuoteString(ss[0], quote)
}

return quote + strings.Join(ss, quote+sep+quote) + quote
}

0 comments on commit 182a62a

Please sign in to comment.